From 8bd2b22b33a57765e20e40eba1826c89f6cfb26a Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 13 Aug 2023 23:00:13 +0200
Subject: [PATCH] Optimize the logit computations in the whisper example. (#434)

---
 candle-examples/examples/whisper/main.rs  | 23 +++++++++-------
 candle-examples/examples/whisper/model.rs | 32 +++++++++--------------
 2 files changed, 26 insertions(+), 29 deletions(-)

diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index bc12692d..99919f8d 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -10,7 +10,7 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use anyhow::{Error as E, Result};
-use candle::{DType, Device, Tensor};
+use candle::{DType, Device, IndexOp, Tensor};
 use candle_nn::{ops::softmax, VarBuilder};
 use clap::{Parser, ValueEnum};
 use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -128,21 +128,24 @@ impl Decoder {
             // The model expects a batch dim but this inference loop does not handle
             // it so we add it at this point.
             let tokens_t = tokens_t.unsqueeze(0)?;
-            let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
-            let logits = logits.squeeze(0)?;
+            let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
 
             // Extract the no speech probability on the first iteration by looking at the first
             // token logits and the probability for the according token.
             if i == 0 {
-                no_speech_prob = softmax(&logits.get(0)?, 0)?
-                    .get(self.no_speech_token as usize)?
+                let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+                no_speech_prob = softmax(&logits, 0)?
+                    .i(self.no_speech_token as usize)?
                     .to_scalar::<f32>()? as f64;
             }
 
-            let (seq_len, _) = logits.dims2()?;
-            let logits = logits
-                .get(seq_len - 1)?
-                .broadcast_add(&self.suppress_tokens)?;
+            let (_, seq_len, _) = ys.dims3()?;
+            let logits = model
+                .decoder
+                .final_linear(&ys.i((..1, seq_len - 1..))?)?
+                .i(0)?
+                .i(0)?;
+            let logits = logits.broadcast_add(&self.suppress_tokens)?;
             let next_token = if t > 0f64 {
                 let prs = softmax(&(&logits / t)?, 0)?;
                 let logits_v: Vec<f32> = prs.to_vec1()?;
@@ -159,7 +162,7 @@ impl Decoder {
             };
             tokens.push(next_token);
             let prob = softmax(&logits, candle::D::Minus1)?
-                .get(next_token as usize)?
+                .i(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
             if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
                 break;
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 2fa04fb0..00d5707e 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -364,11 +364,13 @@ pub struct TextDecoder {
     ln: LayerNorm,
     mask: Tensor,
     span: tracing::Span,
+    span_final: tracing::Span,
 }
 
 impl TextDecoder {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
+        let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
         let n_state = cfg.d_model;
         let n_head = cfg.decoder_attention_heads;
         let n_ctx = cfg.max_target_positions;
@@ -384,7 +386,6 @@ impl TextDecoder {
             .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
             .collect();
         let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
-
         Ok(Self {
             token_embedding,
             positional_embedding,
@@ -392,6 +393,7 @@ impl TextDecoder {
             ln,
             mask,
             span,
+            span_final,
         })
     }
 
@@ -405,12 +407,16 @@ impl TextDecoder {
         for block in self.blocks.iter_mut() {
             x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
         }
-        let x = self.ln.forward(&x)?;
-        let w = self
-            .token_embedding
-            .embeddings()
-            .broadcast_left(x_dims[0])?;
-        let logits = x.matmul(&w.t()?)?;
+        self.ln.forward(&x)
+    }
+
+    pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
+        let b_size = x.dim(0)?;
+        let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
+        let logits = {
+            let _enter = self.span_final.enter();
+            x.matmul(&w.t()?)?
+        };
         Ok(logits)
     }
 }
@@ -432,16 +438,4 @@ impl Whisper {
             config,
         })
     }
-
-    #[allow(dead_code)]
-    pub fn forward(
-        &mut self,
-        mel: &Tensor,
-        tokens: &Tensor,
-        flush_kv_cache: bool,
-    ) -> Result<Tensor> {
-        let enc = self.encoder.forward(mel, flush_kv_cache)?;
-        let dec = self.decoder.forward(tokens, &enc, flush_kv_cache)?;
-        Ok(dec)
-    }
 }
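Not part of the patch: a minimal standalone sketch of why the change above is faster. Before, the decoder projected every position of its output through the tied (vocab x d_model) embedding matrix even though only the last position's logits are sampled at each step; after, `final_linear` is applied to a slice holding just the positions that are needed. The sketch assumes the `candle` crate alias used by the examples and only the calls the patch itself relies on (`Tensor::randn`, `IndexOp`, `broadcast_left`, `t`, `matmul`); the sizes are made-up illustrative values.

use candle::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    // Made-up sizes, roughly whisper-tiny scale; only the ratio matters here.
    let (b_size, seq_len, d_model, vocab) = (1usize, 24usize, 384usize, 51865usize);
    let dev = Device::Cpu;
    // Stand-ins for the decoder output and the tied token-embedding weights.
    let ys = Tensor::randn(0f32, 1f32, (b_size, seq_len, d_model), &dev)?;
    let emb = Tensor::randn(0f32, 1f32, (vocab, d_model), &dev)?;
    let w = emb.broadcast_left(b_size)?; // (b, vocab, d_model), as in final_linear

    // Old path: project every position, then keep only the last row.
    let all_logits = ys.matmul(&w.t()?)?; // (b, seq_len, vocab)
    let last_old = all_logits.i(0)?.i(seq_len - 1)?;

    // New path: slice the last position first, then project that single row.
    let last_new = ys.i((..1, seq_len - 1..))?.matmul(&w.t()?)?.i(0)?.i(0)?;

    // Same logits either way; the second matmul is ~seq_len times smaller.
    println!("{:?} {:?}", last_old.dims(), last_new.dims());
    Ok(())
}

Both paths produce the same last-position logits, but the second matmul touches a (1, d_model) activation instead of (seq_len, d_model), which is where the per-step saving in the example comes from.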