Support longer sequences in language detection. (#428)

Author: Laurent Mazare
Date: 2023-08-13 14:16:15 +02:00
Committed by: GitHub
Parent: 9aca398a4f
Commit: 6d694554b8

@@ -106,13 +106,15 @@ const LANGUAGES: [(&str, &str); 99] = [
 /// Returns the token id for the selected language.
 pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
+    let (_bsize, _, seq_len) = mel.dims3()?;
+    let mel = mel.narrow(2, 0, usize::min(seq_len, model.config.max_source_positions))?;
     let device = mel.device();
     let language_token_ids = LANGUAGES
         .iter()
         .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
         .collect::<Result<Vec<_>>>()?;
     let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
-    let audio_features = model.encoder.forward(mel, true)?;
+    let audio_features = model.encoder.forward(&mel, true)?;
     let tokens = Tensor::new(&[[sot_token]], device)?;
     let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
     let logits = model
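
Context for the change: language detection previously passed the full mel spectrogram to the encoder, which can fail on inputs longer than config.max_source_positions. The commit clamps the time axis to that limit before the forward pass, and since narrow returns an owned Tensor, the encoder call now takes &mel. Below is a minimal standalone sketch of the same clamping pattern, not part of the commit, assuming the candle-core crate (imported as candle_core); the clamp_to_max_positions helper and the concrete shapes are illustrative only.

// Sketch only: clamp a mel spectrogram along its time axis, assuming the
// candle-core Tensor API (dims3, narrow, zeros). Not part of the commit.
use candle_core::{DType, Device, Result, Tensor};

fn clamp_to_max_positions(mel: &Tensor, max_source_positions: usize) -> Result<Tensor> {
    // mel has shape (batch, n_mels, seq_len); keep at most max_source_positions
    // frames along dim 2 and leave shorter inputs untouched.
    let (_bsize, _n_mels, seq_len) = mel.dims3()?;
    mel.narrow(2, 0, usize::min(seq_len, max_source_positions))
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // A dummy spectrogram longer than the 1500-position limit used by the
    // standard Whisper configs (value chosen here for illustration).
    let mel = Tensor::zeros((1, 80, 3000), DType::F32, &device)?;
    let clamped = clamp_to_max_positions(&mel, 1500)?;
    assert_eq!(clamped.dims3()?, (1, 80, 1500));
    // `clamped` is a fresh Tensor, which is why the commit passes `&mel`
    // to the encoder after narrowing.
    Ok(())
}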