Add a KV cache to whisper. (#426)

This commit is contained in:
Laurent Mazare
2023-08-12 22:17:08 +02:00
committed by GitHub
parent a0908d212c
commit 60cd1551ca
3 changed files with 63 additions and 24 deletions

View File

@ -109,8 +109,8 @@ impl Decoder {
}
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &self.model;
let audio_features = model.encoder.forward(mel)?;
let model = &mut self.model;
let audio_features = model.encoder.forward(mel, true)?;
println!("audio features: {:?}", audio_features.dims());
let sample_len = model.config.max_target_positions / 2;
let mut sum_logprob = 0f64;
@ -126,7 +126,7 @@ impl Decoder {
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
let logits = logits.squeeze(0)?;
// Extract the no speech probability on the first iteration by looking at the first
@ -393,10 +393,10 @@ fn main() -> Result<()> {
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let model = Whisper::load(&vb, config)?;
let mut model = Whisper::load(&vb, config)?;
let language_token = match (args.model.is_multilingual(), args.language) {
(true, None) => Some(multilingual::detect_language(&model, &tokenizer, &mel)?),
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),
(false, None) => None,
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
Ok(token_id) => Some(token_id),