diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 4788385b..f0d7cf47 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -70,6 +70,7 @@ struct Decoder { rng: rand::rngs::StdRng, task: Option, timestamps: bool, + verbose: bool, tokenizer: Tokenizer, suppress_tokens: Tensor, sot_token: u32, @@ -82,6 +83,7 @@ struct Decoder { } impl Decoder { + #[allow(clippy::too_many_arguments)] fn new( model: Whisper, tokenizer: Tokenizer, @@ -90,10 +92,16 @@ impl Decoder { language_token: Option, task: Option, timestamps: bool, + verbose: bool, ) -> Result { + let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?; + // Suppress the notimestamps token when in timestamps mode. + // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452 let suppress_tokens: Vec = (0..model.config.vocab_size as u32) .map(|i| { - if model.config.suppress_tokens.contains(&i) { + if model.config.suppress_tokens.contains(&i) + || timestamps && i == no_timestamps_token + { f32::NEG_INFINITY } else { 0f32 @@ -104,7 +112,6 @@ impl Decoder { let sot_token = token_id(&tokenizer, SOT_TOKEN)?; let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?; let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?; - let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?; let eot_token = token_id(&tokenizer, EOT_TOKEN)?; let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?; Ok(Self { @@ -113,6 +120,7 @@ impl Decoder { tokenizer, task, timestamps, + verbose, suppress_tokens, sot_token, transcribe_token, @@ -127,7 +135,9 @@ impl Decoder { fn decode(&mut self, mel: &Tensor, t: f64) -> Result { let model = &mut self.model; let audio_features = model.encoder.forward(mel, true)?; - println!("audio features: {:?}", audio_features.dims()); + if self.verbose { + println!("audio features: {:?}", audio_features.dims()); + } let sample_len = model.config.max_target_positions / 2; let mut sum_logprob = 0f64; let mut no_speech_prob = f64::NAN; @@ -168,6 +178,13 @@ impl Decoder { .final_linear(&ys.i((..1, seq_len - 1..))?)? .i(0)? .i(0)?; + // TODO: Besides suppress tokens, we should apply the heuristics from + // ApplyTimestampRules, i.e.: + // - Timestamps come in pairs, except before EOT. + // - Timestamps should be non-decreasing. + // - If the sum of the probabilities of timestamps is higher than any other tokens, + // only consider timestamps when sampling. + // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439 let logits = logits.broadcast_add(&self.suppress_tokens)?; let next_token = if t > 0f64 { let prs = softmax(&(&logits / t)?, 0)?; @@ -249,7 +266,55 @@ impl Decoder { duration: segment_duration, dr, }; - println!("{seek}: {segment:?}, in {:?}", start.elapsed()); + if self.timestamps { + println!( + "{:.1}s -- {:.1}s", + segment.start, + segment.start + segment.duration, + ); + let mut tokens_to_decode = vec![]; + let mut prev_timestamp_s = 0f32; + for &token in segment.dr.tokens.iter() { + if token == self.sot_token || token == self.eot_token { + continue; + } + // The no_timestamp_token is the last before the timestamp ones. + if token > self.no_timestamps_token { + let timestamp_s = (token - self.no_timestamps_token + 1) as f32 / 50.; + if !tokens_to_decode.is_empty() { + let text = self + .tokenizer + .decode(&tokens_to_decode, true) + .map_err(E::msg)?; + println!(" {:.1}s-{:.1}s: {}", prev_timestamp_s, timestamp_s, text); + tokens_to_decode.clear() + } + prev_timestamp_s = timestamp_s; + } else { + tokens_to_decode.push(token) + } + } + if !tokens_to_decode.is_empty() { + let text = self + .tokenizer + .decode(&tokens_to_decode, true) + .map_err(E::msg)?; + if !text.is_empty() { + println!(" {:.1}s-...: {}", prev_timestamp_s, text); + } + tokens_to_decode.clear() + } + } else { + println!( + "{:.1}s -- {:.1}s: {}", + segment.start, + segment.start + segment.duration, + segment.dr.text, + ) + } + if self.verbose { + println!("{seek}: {segment:?}, in {:?}", start.elapsed()); + } segments.push(segment) } Ok(segments) @@ -357,6 +422,10 @@ struct Args { /// Timestamps mode, this is not fully implemented yet. #[arg(long)] timestamps: bool, + + /// Print the full DecodingResult structure rather than just the text. + #[arg(long)] + verbose: bool, } fn main() -> Result<()> { @@ -466,6 +535,7 @@ fn main() -> Result<()> { language_token, args.task, args.timestamps, + args.verbose, )?; dc.run(&mel)?; Ok(())