diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs index 01722a68..3587b01a 100644 --- a/candle-examples/examples/whisper/multilingual.rs +++ b/candle-examples/examples/whisper/multilingual.rs @@ -106,13 +106,15 @@ const LANGUAGES: [(&str, &str); 99] = [ /// Returns the token id for the selected language. pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result { + let (_bsize, _, seq_len) = mel.dims3()?; + let mel = mel.narrow(2, 0, usize::min(seq_len, model.config.max_source_positions))?; let device = mel.device(); let language_token_ids = LANGUAGES .iter() .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>"))) .collect::>>()?; let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?; - let audio_features = model.encoder.forward(mel, true)?; + let audio_features = model.encoder.forward(&mel, true)?; let tokens = Tensor::new(&[[sot_token]], device)?; let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?; let logits = model