mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
@ -146,11 +146,8 @@ impl Decoder {
|
||||
tokens.push(language_token);
|
||||
}
|
||||
match self.task {
|
||||
Some(Task::Transcribe) => tokens.push(self.transcribe_token),
|
||||
None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
|
||||
Some(Task::Translate) => tokens.push(self.translate_token),
|
||||
None => {
|
||||
// Nothing in this case, same as the Python implementation.
|
||||
}
|
||||
}
|
||||
if !self.timestamps {
|
||||
tokens.push(self.no_timestamps_token);
|
||||
|
@ -117,11 +117,8 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
|
||||
let audio_features = model.encoder.forward(&mel, true)?;
|
||||
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||
let logits = model
|
||||
.decoder
|
||||
.forward(&tokens, &audio_features, true)?
|
||||
.i(0)?
|
||||
.i(0)?;
|
||||
let ys = model.decoder.forward(&tokens, &audio_features, true)?;
|
||||
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
let logits = logits.index_select(&language_token_ids, 0)?;
|
||||
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
||||
let probs = probs.to_vec1::<f32>()?;
|
||||
|
Reference in New Issue
Block a user