diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 3090ae8f..61db40bc 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -38,18 +38,9 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 
 // Tokenizer dependent bits.
 const SOT_TOKEN: &str = "<|startoftranscript|>";
+const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
 const EOT_TOKEN: &str = "<|endoftext|>";
 const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
-// From the _get_suppress_tokens function + 50362 (no timestamp)
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
-const SUPPRESS_TOKENS: [u32; 91] = [
-    1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
-    366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
-    1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
-    10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
-    19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
-    47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
-];
 
 #[allow(dead_code)]
 #[derive(Debug, Clone)]
@@ -76,15 +67,23 @@ struct Decoder {
     tokenizer: Tokenizer,
     suppress_tokens: Tensor,
     sot_token: u32,
+    transcribe_token: u32,
     eot_token: u32,
     no_speech_token: u32,
+    language_token: Option<u32>,
 }
 
 impl Decoder {
-    fn new(model: Whisper, tokenizer: Tokenizer, seed: u64, device: &Device) -> Result<Self> {
+    fn new(
+        model: Whisper,
+        tokenizer: Tokenizer,
+        seed: u64,
+        device: &Device,
+        language_token: Option<u32>,
+    ) -> Result<Self> {
         let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
             .map(|i| {
-                if SUPPRESS_TOKENS.contains(&i) {
+                if model.config.suppress_tokens.contains(&i) {
                     f32::NEG_INFINITY
                 } else {
                     0f32
@@ -93,6 +92,7 @@ impl Decoder {
             .collect();
         let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
         let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
+        let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
         let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
         let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
         Ok(Self {
@@ -101,8 +101,10 @@ impl Decoder {
             tokenizer,
             suppress_tokens,
             sot_token,
+            transcribe_token,
             eot_token,
             no_speech_token,
+            language_token,
         })
     }
 
@@ -114,6 +116,10 @@ impl Decoder {
         let mut sum_logprob = 0f64;
         let mut no_speech_prob = f64::NAN;
         let mut tokens = vec![self.sot_token];
+        if let Some(language_token) = self.language_token {
+            tokens.push(language_token)
+        }
+        tokens.push(self.transcribe_token);
         for i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
 
@@ -236,23 +242,29 @@ pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
 enum WhichModel {
     Tiny,
     TinyEn,
+    Base,
+    BaseEn,
     SmallEn,
     MediumEn,
+    LargeV2,
 }
 
 impl WhichModel {
     fn is_multilingual(&self) -> bool {
         match self {
-            Self::Tiny => true,
-            Self::TinyEn | Self::SmallEn | Self::MediumEn => false,
+            Self::Tiny | Self::Base | Self::LargeV2 => true,
+            Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
         }
     }
     fn model_and_revision(&self) -> (&'static str, &'static str) {
         match self {
             Self::Tiny => ("openai/whisper-tiny", "main"),
             Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
+            Self::Base => ("openai/whisper-base", "refs/pr/22"),
+            Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
             Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
             Self::MediumEn => ("openai/whisper-medium.en", "refs/pr/11"),
+            Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
         }
     }
 }
@@ -379,10 +391,12 @@ fn main() -> Result<()> {
     let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
     let model = Whisper::load(&vb, config)?;
 
-    if args.model.is_multilingual() {
-        multilingual::detect_language(&model, &tokenizer, &mel)?
-    }
-    let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?;
+    let language_token = if args.model.is_multilingual() {
+        Some(multilingual::detect_language(&model, &tokenizer, &mel)?)
+    } else {
+        None
+    };
+    let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
     dc.run(&mel)?;
     Ok(())
 }
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 7015199d..c61882bc 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -16,11 +16,21 @@ pub struct Config {
     // pub n_text_state: usize,
     pub decoder_attention_heads: usize, // n_text_head
     pub decoder_layers: usize,          // n_text_layer
+    pub suppress_tokens: Vec<u32>,
 }
 
 impl Config {
     #[allow(dead_code)]
     pub fn tiny_en() -> Self {
+        let suppress_tokens = vec![
+            1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93,
+            357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391,
+            1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329,
+            7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306,
+            16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409,
+            34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361,
+            50362,
+        ];
         Self {
             num_mel_bins: 80,
             vocab_size: 51864,
@@ -32,6 +42,7 @@ impl Config {
             // n_text_state: 384,
             decoder_attention_heads: 6,
             decoder_layers: 4,
+            suppress_tokens,
         }
     }
 }
diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs
index 5436cbd3..1342ad55 100644
--- a/candle-examples/examples/whisper/multilingual.rs
+++ b/candle-examples/examples/whisper/multilingual.rs
@@ -104,7 +104,8 @@ const LANGUAGES: [(&str, &str); 99] = [
     ("su", "sundanese"),
 ];
 
-pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<()> {
+/// Returns the token id for the selected language.
+pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
     let device = mel.device();
     let language_token_ids = LANGUAGES
         .iter()
@@ -114,14 +115,11 @@ pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) ->
     let audio_features = model.encoder.forward(mel)?;
     let tokens = Tensor::new(&[[sot_token]], device)?;
     let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
-    println!("{tokens}");
-    println!("{audio_features}");
     let logits = model
         .decoder
         .forward(&tokens, &audio_features)?
         .i(0)?
         .i(0)?;
-    println!("{logits}");
     let logits = logits.index_select(&language_token_ids, 0)?;
     let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
     let probs = probs.to_vec1::<f32>()?;
@@ -130,5 +128,6 @@ pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) ->
     for ((_, language), p) in probs.iter().take(5) {
         println!("{language}: {p}")
     }
-    Ok(())
+    let language = crate::token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
+    Ok(language)
 }