mirror of https://github.com/huggingface/candle.git
@@ -146,11 +146,8 @@ impl Decoder {
             tokens.push(language_token);
         }
         match self.task {
-            Some(Task::Transcribe) => tokens.push(self.transcribe_token),
+            None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
             Some(Task::Translate) => tokens.push(self.translate_token),
-            None => {
-                // Nothing in this case, same as the Python implementation.
-            }
         }
         if !self.timestamps {
             tokens.push(self.no_timestamps_token);
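The rewritten match above folds the `None` task into the `Transcribe` arm, so a request with no explicit task now pushes the transcribe token instead of pushing nothing. A minimal sketch of that arm-merging pattern, using a simplified stand-alone `Task` enum and placeholder token ids rather than the example's real decoder state:

enum Task {
    Transcribe,
    Translate,
}

// Append the task token for the prompt; `None` defaults to transcription,
// matching the merged `None | Some(Task::Transcribe)` arm in the diff.
fn push_task_token(task: Option<Task>, tokens: &mut Vec<u32>) {
    // Placeholder token ids, not the real Whisper vocabulary values.
    let (transcribe_token, translate_token) = (1u32, 2u32);
    match task {
        None | Some(Task::Transcribe) => tokens.push(transcribe_token),
        Some(Task::Translate) => tokens.push(translate_token),
    }
}

fn main() {
    let mut tokens = Vec::new();
    push_task_token(None, &mut tokens);
    push_task_token(Some(Task::Translate), &mut tokens);
    assert_eq!(tokens, vec![1, 2]); // `None` now emits the transcribe token by default
}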
@@ -117,11 +117,8 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
     let audio_features = model.encoder.forward(&mel, true)?;
     let tokens = Tensor::new(&[[sot_token]], device)?;
     let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
-    let logits = model
-        .decoder
-        .forward(&tokens, &audio_features, true)?
-        .i(0)?
-        .i(0)?;
+    let ys = model.decoder.forward(&tokens, &audio_features, true)?;
+    let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
     let logits = logits.index_select(&language_token_ids, 0)?;
     let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
     let probs = probs.to_vec1::<f32>()?;
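The replacement lines run the decoder once and apply `final_linear` to the first position of its output, so the tensor fed into `index_select` holds vocabulary logits rather than raw decoder hidden states. A minimal sketch of the selection-and-softmax tail that follows, assuming the same candle crates (`candle_core`, `candle_nn`) as the example and substituting placeholder logits and language-token ids for the real decoder output:

use candle_core::{Device, Tensor, D};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // Stand-in for the vocabulary logits at the start-of-transcript position.
    let logits_data: Vec<f32> = vec![0.1, 2.5, 0.3, 1.0];
    let logits = Tensor::new(logits_data.as_slice(), &device)?;
    // Placeholder vocabulary ids for the <|xx|> language tokens.
    let language_token_ids: Vec<u32> = vec![1, 3];
    let language_token_ids = Tensor::new(language_token_ids.as_slice(), &device)?;
    // Same tail as the diff: gather the language logits, softmax, read back as f32.
    let logits = logits.index_select(&language_token_ids, 0)?;
    let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
    let probs = probs.to_vec1::<f32>()?;
    println!("language probabilities: {probs:?}");
    Ok(())
}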