mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
More multilingual support for whisper. (#419)
* More multilingual support for whisper. * Use the language token appropriately.
This commit is contained in:
@ -104,7 +104,8 @@ const LANGUAGES: [(&str, &str); 99] = [
|
||||
("su", "sundanese"),
|
||||
];
|
||||
|
||||
pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<()> {
|
||||
/// Returns the token id for the selected language.
|
||||
pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
|
||||
let device = mel.device();
|
||||
let language_token_ids = LANGUAGES
|
||||
.iter()
|
||||
@ -114,14 +115,11 @@ pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) ->
|
||||
let audio_features = model.encoder.forward(mel)?;
|
||||
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||
println!("{tokens}");
|
||||
println!("{audio_features}");
|
||||
let logits = model
|
||||
.decoder
|
||||
.forward(&tokens, &audio_features)?
|
||||
.i(0)?
|
||||
.i(0)?;
|
||||
println!("{logits}");
|
||||
let logits = logits.index_select(&language_token_ids, 0)?;
|
||||
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
||||
let probs = probs.to_vec1::<f32>()?;
|
||||
@ -130,5 +128,6 @@ pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) ->
|
||||
for ((_, language), p) in probs.iter().take(5) {
|
||||
println!("{language}: {p}")
|
||||
}
|
||||
Ok(())
|
||||
let language = crate::token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
|
||||
Ok(language)
|
||||
}
|
||||
|
Reference in New Issue
Block a user