From 19042962d5ae3ab17866522a0d2d99e873624441 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 1 Sep 2023 21:04:07 +0200 Subject: [PATCH] Whisper fix (#711) * Remove unnecessary file. * Whisper fix. --- candle-examples/examples/whisper/main.rs | 5 +---- candle-examples/examples/whisper/multilingual.rs | 7 ++----- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index fc64d458..5dd8ee20 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -146,11 +146,8 @@ impl Decoder { tokens.push(language_token); } match self.task { - Some(Task::Transcribe) => tokens.push(self.transcribe_token), + None | Some(Task::Transcribe) => tokens.push(self.transcribe_token), Some(Task::Translate) => tokens.push(self.translate_token), - None => { - // Nothing in this case, same as the Python implementation. - } } if !self.timestamps { tokens.push(self.no_timestamps_token); diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs index 3587b01a..bc0bae1f 100644 --- a/candle-examples/examples/whisper/multilingual.rs +++ b/candle-examples/examples/whisper/multilingual.rs @@ -117,11 +117,8 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) let audio_features = model.encoder.forward(&mel, true)?; let tokens = Tensor::new(&[[sot_token]], device)?; let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?; - let logits = model - .decoder - .forward(&tokens, &audio_features, true)? - .i(0)? - .i(0)?; + let ys = model.decoder.forward(&tokens, &audio_features, true)?; + let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?; let logits = logits.index_select(&language_token_ids, 0)?; let probs = candle_nn::ops::softmax(&logits, D::Minus1)?; let probs = probs.to_vec1::()?;