From 3164cd24fa56db2056191ff5c06d2cae16a34d05 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 17 Aug 2023 16:59:36 +0100
Subject: [PATCH] =?UTF-8?q?Replicate=20the=20sot-token=20logic=20from=20th?=
 =?UTF-8?q?e=20Python=20implementation=20more=20acc=E2=80=A6=20(#491)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Replicate the sot-token logic from the Python implementation more accurately.

* Add a flag to control the timestamp mode.
---
 candle-examples/examples/whisper/main.rs | 54 ++++++++++++++++++++++--
 1 file changed, 50 insertions(+), 4 deletions(-)

diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 9f8810a7..4ea60fb4 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -41,6 +41,8 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 // Tokenizer dependent bits.
 const SOT_TOKEN: &str = "<|startoftranscript|>";
 const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
+const TRANSLATE_TOKEN: &str = "<|translate|>";
+const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
 const EOT_TOKEN: &str = "<|endoftext|>";
 const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
 
@@ -66,12 +68,16 @@ struct Segment {
 struct Decoder {
     model: Whisper,
     rng: rand::rngs::StdRng,
+    task: Option<Task>,
+    timestamps: bool,
     tokenizer: Tokenizer,
     suppress_tokens: Tensor,
     sot_token: u32,
     transcribe_token: u32,
+    translate_token: u32,
     eot_token: u32,
     no_speech_token: u32,
+    no_timestamps_token: u32,
     language_token: Option<u32>,
 }
 
@@ -82,6 +88,8 @@ impl Decoder {
         seed: u64,
         device: &Device,
         language_token: Option<u32>,
+        task: Option<Task>,
+        timestamps: bool,
     ) -> Result<Self> {
         let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
             .map(|i| {
@@ -95,18 +103,24 @@ impl Decoder {
         let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
         let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
         let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
+        let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
+        let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
         let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
         let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
         Ok(Self {
             model,
             rng: rand::rngs::StdRng::seed_from_u64(seed),
             tokenizer,
+            task,
+            timestamps,
             suppress_tokens,
             sot_token,
             transcribe_token,
+            translate_token,
             eot_token,
             no_speech_token,
             language_token,
+            no_timestamps_token,
         })
     }
 
@@ -118,10 +132,19 @@ impl Decoder {
         let mut sum_logprob = 0f64;
         let mut no_speech_prob = f64::NAN;
         let mut tokens = vec![self.sot_token];
-        if let Some(language_token) = self.language_token {
-            tokens.push(language_token)
+        match self.task {
+            Some(Task::Transcribe) => tokens.push(self.transcribe_token),
+            Some(Task::Translate) => tokens.push(self.translate_token),
+            None => {
+                // Nothing in this case, same as the Python implementation.
+            }
+        }
+        if let Some(language_token) = self.language_token {
+            tokens.push(language_token);
+        }
+        if !self.timestamps {
+            tokens.push(self.no_timestamps_token);
         }
-        tokens.push(self.transcribe_token);
 
         for i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
@@ -240,6 +263,12 @@ pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Task {
+    Transcribe,
+    Translate,
+}
+
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum WhichModel {
     Tiny,
@@ -313,6 +342,15 @@ struct Args {
     /// Language.
     #[arg(long)]
     language: Option<String>,
+
+    /// Task, when no task is specified, the input tokens contain only the sot token which can
+    /// improve things when in no-timestamp mode.
+    #[arg(long)]
+    task: Option<Task>,
+
+    /// Timestamps mode, this is not fully implemented yet.
+    #[arg(long)]
+    timestamps: bool,
 }
 
 fn main() -> Result<()> {
@@ -414,7 +452,15 @@ fn main() -> Result<()> {
             anyhow::bail!("a language cannot be set for non-multilingual models")
         }
     };
-    let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
+    let mut dc = Decoder::new(
+        model,
+        tokenizer,
+        args.seed,
+        &device,
+        language_token,
+        args.task,
+        args.timestamps,
+    )?;
     dc.run(&mel)?;
     Ok(())
 }
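
For reference, a minimal, self-contained Rust sketch of the prompt ordering the patched decoder builds, mirroring the logic added in the decode hunk above. The function and parameter names are illustrative only, and the token ids are placeholders rather than real tokenizer ids.

    // Mirrors the prompt assembly from the patch: sot token, optional task
    // token, optional language token, then the no-timestamps token unless
    // timestamp mode is enabled.
    #[derive(Clone, Copy, Debug)]
    enum Task {
        Transcribe,
        Translate,
    }

    fn prompt_tokens(
        sot: u32,
        transcribe: u32,
        translate: u32,
        no_timestamps: u32,
        language: Option<u32>,
        task: Option<Task>,
        timestamps: bool,
    ) -> Vec<u32> {
        let mut tokens = vec![sot];
        match task {
            Some(Task::Transcribe) => tokens.push(transcribe),
            Some(Task::Translate) => tokens.push(translate),
            // No task: keep only the sot token, as in the Python implementation.
            None => {}
        }
        if let Some(language) = language {
            tokens.push(language);
        }
        if !timestamps {
            tokens.push(no_timestamps);
        }
        tokens
    }

    fn main() {
        // Placeholder ids: sot=1, transcribe=2, translate=3, no_timestamps=4, language=5.
        assert_eq!(prompt_tokens(1, 2, 3, 4, Some(5), None, false), vec![1, 5, 4]);
        assert_eq!(prompt_tokens(1, 2, 3, 4, Some(5), Some(Task::Translate), true), vec![1, 3, 5]);
        assert_eq!(prompt_tokens(1, 2, 3, 4, None, Some(Task::Transcribe), false), vec![1, 2, 4]);
    }

With clap's derive, the new task and timestamps fields should surface on the example's command line as a --task option accepting transcribe or translate and a --timestamps flag; leaving --task unset keeps the task token out of the prompt, as the doc comment added in the patch describes.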