More multilingual support for whisper. (#419)

* More multilingual support for whisper.

* Use the language token appropriately.
This commit is contained in:
Laurent Mazare
2023-08-12 16:32:52 +02:00
committed by GitHub
parent 0c3f109faa
commit 0741ebbd51
3 changed files with 47 additions and 23 deletions

View File

@@ -104,7 +104,8 @@ const LANGUAGES: [(&str, &str); 99] = [
("su", "sundanese"),
];
pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<()> {
/// Returns the token id for the selected language.
pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
let device = mel.device();
let language_token_ids = LANGUAGES
.iter()
@@ -114,14 +115,11 @@ pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) ->
let audio_features = model.encoder.forward(mel)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
println!("{tokens}");
println!("{audio_features}");
let logits = model
.decoder
.forward(&tokens, &audio_features)?
.i(0)?
.i(0)?;
println!("{logits}");
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;
@@ -130,5 +128,6 @@ pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) ->
for ((_, language), p) in probs.iter().take(5) {
println!("{language}: {p}")
}
Ok(())
let language = crate::token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
Ok(language)
}