Whisper fix (#711)

* Remove unnecessary file.

* Whisper fix.
This commit is contained in:
Laurent Mazare
2023-09-01 21:04:07 +02:00
committed by GitHub
parent 731e3ffb03
commit 19042962d5
2 changed files with 3 additions and 9 deletions

View File

@ -146,11 +146,8 @@ impl Decoder {
tokens.push(language_token);
}
match self.task {
Some(Task::Transcribe) => tokens.push(self.transcribe_token),
None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
Some(Task::Translate) => tokens.push(self.translate_token),
None => {
// Nothing in this case, same as the Python implementation.
}
}
if !self.timestamps {
tokens.push(self.no_timestamps_token);

View File

@ -117,11 +117,8 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
let audio_features = model.encoder.forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
let logits = model
.decoder
.forward(&tokens, &audio_features, true)?
.i(0)?
.i(0)?;
let ys = model.decoder.forward(&tokens, &audio_features, true)?;
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;