mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 04:22:50 +00:00
feat: support microphone whisper streaming (#1678)
* feat: support microphone whisper streaming * fix: cleanup print stmts and adjust how input is read * fix: remove incorrect comment * feat: split into new example and simplify * fix: feature flag example file * fix: fmt fixes * feat: simplify and remove redundant files
This commit is contained in:
137
candle-examples/examples/whisper-microphone/multilingual.rs
Normal file
137
candle-examples/examples/whisper-microphone/multilingual.rs
Normal file
@ -0,0 +1,137 @@
|
||||
use crate::{token_id, Model};
|
||||
use candle::{IndexOp, Result, Tensor, D};
|
||||
use candle_transformers::models::whisper::{self as m};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const LANGUAGES: [(&str, &str); 99] = [
|
||||
("en", "english"),
|
||||
("zh", "chinese"),
|
||||
("de", "german"),
|
||||
("es", "spanish"),
|
||||
("ru", "russian"),
|
||||
("ko", "korean"),
|
||||
("fr", "french"),
|
||||
("ja", "japanese"),
|
||||
("pt", "portuguese"),
|
||||
("tr", "turkish"),
|
||||
("pl", "polish"),
|
||||
("ca", "catalan"),
|
||||
("nl", "dutch"),
|
||||
("ar", "arabic"),
|
||||
("sv", "swedish"),
|
||||
("it", "italian"),
|
||||
("id", "indonesian"),
|
||||
("hi", "hindi"),
|
||||
("fi", "finnish"),
|
||||
("vi", "vietnamese"),
|
||||
("he", "hebrew"),
|
||||
("uk", "ukrainian"),
|
||||
("el", "greek"),
|
||||
("ms", "malay"),
|
||||
("cs", "czech"),
|
||||
("ro", "romanian"),
|
||||
("da", "danish"),
|
||||
("hu", "hungarian"),
|
||||
("ta", "tamil"),
|
||||
("no", "norwegian"),
|
||||
("th", "thai"),
|
||||
("ur", "urdu"),
|
||||
("hr", "croatian"),
|
||||
("bg", "bulgarian"),
|
||||
("lt", "lithuanian"),
|
||||
("la", "latin"),
|
||||
("mi", "maori"),
|
||||
("ml", "malayalam"),
|
||||
("cy", "welsh"),
|
||||
("sk", "slovak"),
|
||||
("te", "telugu"),
|
||||
("fa", "persian"),
|
||||
("lv", "latvian"),
|
||||
("bn", "bengali"),
|
||||
("sr", "serbian"),
|
||||
("az", "azerbaijani"),
|
||||
("sl", "slovenian"),
|
||||
("kn", "kannada"),
|
||||
("et", "estonian"),
|
||||
("mk", "macedonian"),
|
||||
("br", "breton"),
|
||||
("eu", "basque"),
|
||||
("is", "icelandic"),
|
||||
("hy", "armenian"),
|
||||
("ne", "nepali"),
|
||||
("mn", "mongolian"),
|
||||
("bs", "bosnian"),
|
||||
("kk", "kazakh"),
|
||||
("sq", "albanian"),
|
||||
("sw", "swahili"),
|
||||
("gl", "galician"),
|
||||
("mr", "marathi"),
|
||||
("pa", "punjabi"),
|
||||
("si", "sinhala"),
|
||||
("km", "khmer"),
|
||||
("sn", "shona"),
|
||||
("yo", "yoruba"),
|
||||
("so", "somali"),
|
||||
("af", "afrikaans"),
|
||||
("oc", "occitan"),
|
||||
("ka", "georgian"),
|
||||
("be", "belarusian"),
|
||||
("tg", "tajik"),
|
||||
("sd", "sindhi"),
|
||||
("gu", "gujarati"),
|
||||
("am", "amharic"),
|
||||
("yi", "yiddish"),
|
||||
("lo", "lao"),
|
||||
("uz", "uzbek"),
|
||||
("fo", "faroese"),
|
||||
("ht", "haitian creole"),
|
||||
("ps", "pashto"),
|
||||
("tk", "turkmen"),
|
||||
("nn", "nynorsk"),
|
||||
("mt", "maltese"),
|
||||
("sa", "sanskrit"),
|
||||
("lb", "luxembourgish"),
|
||||
("my", "myanmar"),
|
||||
("bo", "tibetan"),
|
||||
("tl", "tagalog"),
|
||||
("mg", "malagasy"),
|
||||
("as", "assamese"),
|
||||
("tt", "tatar"),
|
||||
("haw", "hawaiian"),
|
||||
("ln", "lingala"),
|
||||
("ha", "hausa"),
|
||||
("ba", "bashkir"),
|
||||
("jw", "javanese"),
|
||||
("su", "sundanese"),
|
||||
];
|
||||
|
||||
/// Returns the token id for the selected language.
|
||||
pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
|
||||
let (_bsize, _, seq_len) = mel.dims3()?;
|
||||
let mel = mel.narrow(
|
||||
2,
|
||||
0,
|
||||
usize::min(seq_len, model.config().max_source_positions),
|
||||
)?;
|
||||
let device = mel.device();
|
||||
let language_token_ids = LANGUAGES
|
||||
.iter()
|
||||
.map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
|
||||
let audio_features = model.encoder_forward(&mel, true)?;
|
||||
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
|
||||
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
let logits = logits.index_select(&language_token_ids, 0)?;
|
||||
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
||||
let probs = probs.to_vec1::<f32>()?;
|
||||
let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
|
||||
probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for ((_, language), p) in probs.iter().take(5) {
|
||||
println!("{language}: {p}")
|
||||
}
|
||||
let language = token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
|
||||
Ok(language)
|
||||
}
|
Reference in New Issue
Block a user