diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index d5f91053..3090ae8f 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -1,7 +1,6 @@ // https://github.com/openai/whisper/blob/main/whisper/model.py // TODO: // - kv-cache support? -// - Language detection? // - Batch size greater than 1. // - More token filters (SuppressBlanks, ApplyTimestampRules). @@ -19,6 +18,7 @@ use tokenizers::Tokenizer; mod audio; mod model; use model::{Config, Whisper}; +mod multilingual; const DTYPE: DType = DType::F32; @@ -37,9 +37,9 @@ const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; // Tokenizer dependent bits. -const SOT_TOKEN: u32 = 50257; -const EOT_TOKEN: u32 = 50256; -const NO_SPEECH_TOKEN: u32 = 50361; +const SOT_TOKEN: &str = "<|startoftranscript|>"; +const EOT_TOKEN: &str = "<|endoftext|>"; +const NO_SPEECH_TOKEN: &str = "<|nocaptions|>"; // From the _get_suppress_tokens function + 50362 (no timestamp) // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605 const SUPPRESS_TOKENS: [u32; 91] = [ @@ -75,6 +75,9 @@ struct Decoder { rng: rand::rngs::StdRng, tokenizer: Tokenizer, suppress_tokens: Tensor, + sot_token: u32, + eot_token: u32, + no_speech_token: u32, } impl Decoder { @@ -89,11 +92,17 @@ impl Decoder { }) .collect(); let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?; + let sot_token = token_id(&tokenizer, SOT_TOKEN)?; + let eot_token = token_id(&tokenizer, EOT_TOKEN)?; + let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?; Ok(Self { model, rng: rand::rngs::StdRng::seed_from_u64(seed), tokenizer, suppress_tokens, + sot_token, + eot_token, + no_speech_token, }) } @@ -104,7 +113,7 @@ impl Decoder { let sample_len = model.config.max_target_positions / 2; let mut sum_logprob = 0f64; let mut no_speech_prob = f64::NAN; - let mut tokens = vec![SOT_TOKEN]; + let mut tokens = vec![self.sot_token]; for i in 0..sample_len { let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?; @@ -118,7 +127,7 @@ impl Decoder { // token logits and the probability for the according token. if i == 0 { no_speech_prob = softmax(&logits.get(0)?, 0)? - .get(NO_SPEECH_TOKEN as usize)? + .get(self.no_speech_token as usize)? .to_scalar::()? as f64; } @@ -144,7 +153,7 @@ impl Decoder { let prob = softmax(&logits, candle::D::Minus1)? .get(next_token as usize)? .to_scalar::()? as f64; - if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions { + if next_token == self.eot_token || tokens.len() > model.config.max_target_positions { break; } sum_logprob += prob.ln(); @@ -216,19 +225,34 @@ impl Decoder { } } +pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result { + match tokenizer.token_to_id(token) { + None => candle::bail!("no token-id for {token}"), + Some(id) => Ok(id), + } +} + #[derive(Clone, Copy, Debug, ValueEnum)] enum WhichModel { Tiny, - Small, - Medium, + TinyEn, + SmallEn, + MediumEn, } impl WhichModel { + fn is_multilingual(&self) -> bool { + match self { + Self::Tiny => true, + Self::TinyEn | Self::SmallEn | Self::MediumEn => false, + } + } fn model_and_revision(&self) -> (&'static str, &'static str) { match self { - Self::Tiny => ("openai/whisper-tiny.en", "refs/pr/15"), - Self::Small => ("openai/whisper-small.en", "refs/pr/10"), - Self::Medium => ("openai/whisper-medium.en", "refs/pr/11"), + Self::Tiny => ("openai/whisper-tiny", "main"), + Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"), + Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"), + Self::MediumEn => ("openai/whisper-medium.en", "refs/pr/11"), } } } @@ -249,7 +273,7 @@ struct Args { revision: Option, /// The model to be used, can be tiny, small, medium. - #[arg(long, default_value = "tiny")] + #[arg(long, default_value = "tiny-en")] model: WhichModel, /// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively @@ -354,6 +378,10 @@ fn main() -> Result<()> { let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; let model = Whisper::load(&vb, config)?; + + if args.model.is_multilingual() { + multilingual::detect_language(&model, &tokenizer, &mel)? + } let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?; dc.run(&mel)?; Ok(()) diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs new file mode 100644 index 00000000..5436cbd3 --- /dev/null +++ b/candle-examples/examples/whisper/multilingual.rs @@ -0,0 +1,134 @@ +use crate::Whisper; +use candle::{IndexOp, Result, Tensor, D}; +use tokenizers::Tokenizer; + +const LANGUAGES: [(&str, &str); 99] = [ + ("en", "english"), + ("zh", "chinese"), + ("de", "german"), + ("es", "spanish"), + ("ru", "russian"), + ("ko", "korean"), + ("fr", "french"), + ("ja", "japanese"), + ("pt", "portuguese"), + ("tr", "turkish"), + ("pl", "polish"), + ("ca", "catalan"), + ("nl", "dutch"), + ("ar", "arabic"), + ("sv", "swedish"), + ("it", "italian"), + ("id", "indonesian"), + ("hi", "hindi"), + ("fi", "finnish"), + ("vi", "vietnamese"), + ("he", "hebrew"), + ("uk", "ukrainian"), + ("el", "greek"), + ("ms", "malay"), + ("cs", "czech"), + ("ro", "romanian"), + ("da", "danish"), + ("hu", "hungarian"), + ("ta", "tamil"), + ("no", "norwegian"), + ("th", "thai"), + ("ur", "urdu"), + ("hr", "croatian"), + ("bg", "bulgarian"), + ("lt", "lithuanian"), + ("la", "latin"), + ("mi", "maori"), + ("ml", "malayalam"), + ("cy", "welsh"), + ("sk", "slovak"), + ("te", "telugu"), + ("fa", "persian"), + ("lv", "latvian"), + ("bn", "bengali"), + ("sr", "serbian"), + ("az", "azerbaijani"), + ("sl", "slovenian"), + ("kn", "kannada"), + ("et", "estonian"), + ("mk", "macedonian"), + ("br", "breton"), + ("eu", "basque"), + ("is", "icelandic"), + ("hy", "armenian"), + ("ne", "nepali"), + ("mn", "mongolian"), + ("bs", "bosnian"), + ("kk", "kazakh"), + ("sq", "albanian"), + ("sw", "swahili"), + ("gl", "galician"), + ("mr", "marathi"), + ("pa", "punjabi"), + ("si", "sinhala"), + ("km", "khmer"), + ("sn", "shona"), + ("yo", "yoruba"), + ("so", "somali"), + ("af", "afrikaans"), + ("oc", "occitan"), + ("ka", "georgian"), + ("be", "belarusian"), + ("tg", "tajik"), + ("sd", "sindhi"), + ("gu", "gujarati"), + ("am", "amharic"), + ("yi", "yiddish"), + ("lo", "lao"), + ("uz", "uzbek"), + ("fo", "faroese"), + ("ht", "haitian creole"), + ("ps", "pashto"), + ("tk", "turkmen"), + ("nn", "nynorsk"), + ("mt", "maltese"), + ("sa", "sanskrit"), + ("lb", "luxembourgish"), + ("my", "myanmar"), + ("bo", "tibetan"), + ("tl", "tagalog"), + ("mg", "malagasy"), + ("as", "assamese"), + ("tt", "tatar"), + ("haw", "hawaiian"), + ("ln", "lingala"), + ("ha", "hausa"), + ("ba", "bashkir"), + ("jw", "javanese"), + ("su", "sundanese"), +]; + +pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<()> { + let device = mel.device(); + let language_token_ids = LANGUAGES + .iter() + .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>"))) + .collect::>>()?; + let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?; + let audio_features = model.encoder.forward(mel)?; + let tokens = Tensor::new(&[[sot_token]], device)?; + let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?; + println!("{tokens}"); + println!("{audio_features}"); + let logits = model + .decoder + .forward(&tokens, &audio_features)? + .i(0)? + .i(0)?; + println!("{logits}"); + let logits = logits.index_select(&language_token_ids, 0)?; + let probs = candle_nn::ops::softmax(&logits, D::Minus1)?; + let probs = probs.to_vec1::()?; + let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::>(); + probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for ((_, language), p) in probs.iter().take(5) { + println!("{language}: {p}") + } + Ok(()) +}