Optimize the logit computations in the whisper example. (#434)

This commit is contained in:
Laurent Mazare
2023-08-13 23:00:13 +02:00
committed by GitHub
parent d379a76a9e
commit 8bd2b22b33
2 changed files with 26 additions and 29 deletions

View File

@ -10,7 +10,7 @@ extern crate accelerate_src;
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
@ -128,21 +128,24 @@ impl Decoder {
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
let logits = logits.squeeze(0)?;
let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
no_speech_prob = softmax(&logits.get(0)?, 0)?
.get(self.no_speech_token as usize)?
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
no_speech_prob = softmax(&logits, 0)?
.i(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
}
let (seq_len, _) = logits.dims2()?;
let logits = logits
.get(seq_len - 1)?
.broadcast_add(&self.suppress_tokens)?;
let (_, seq_len, _) = ys.dims3()?;
let logits = model
.decoder
.final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
let logits = logits.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
@ -159,7 +162,7 @@ impl Decoder {
};
tokens.push(next_token);
let prob = softmax(&logits, candle::D::Minus1)?
.get(next_token as usize)?
.i(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
break;

View File

@ -364,11 +364,13 @@ pub struct TextDecoder {
ln: LayerNorm,
mask: Tensor,
span: tracing::Span,
span_final: tracing::Span,
}
impl TextDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
@ -384,7 +386,6 @@ impl TextDecoder {
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
Ok(Self {
token_embedding,
positional_embedding,
@ -392,6 +393,7 @@ impl TextDecoder {
ln,
mask,
span,
span_final,
})
}
@ -405,12 +407,16 @@ impl TextDecoder {
for block in self.blocks.iter_mut() {
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
}
let x = self.ln.forward(&x)?;
let w = self
.token_embedding
.embeddings()
.broadcast_left(x_dims[0])?;
let logits = x.matmul(&w.t()?)?;
self.ln.forward(&x)
}
pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
let b_size = x.dim(0)?;
let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
let logits = {
let _enter = self.span_final.enter();
x.matmul(&w.t()?)?
};
Ok(logits)
}
}
@ -432,16 +438,4 @@ impl Whisper {
config,
})
}
#[allow(dead_code)]
pub fn forward(
&mut self,
mel: &Tensor,
tokens: &Tensor,
flush_kv_cache: bool,
) -> Result<Tensor> {
let enc = self.encoder.forward(mel, flush_kv_cache)?;
let dec = self.decoder.forward(tokens, &enc, flush_kv_cache)?;
Ok(dec)
}
}