mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Optimize the logit computations in the whisper example. (#434)
This commit is contained in:
@ -10,7 +10,7 @@ extern crate accelerate_src;
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle::{DType, Device, IndexOp, Tensor};
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
@ -128,21 +128,24 @@ impl Decoder {
|
||||
// The model expects a batch dim but this inference loop does not handle
|
||||
// it so we add it at this point.
|
||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||
let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
|
||||
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
// token logits and the probability for the according token.
|
||||
if i == 0 {
|
||||
no_speech_prob = softmax(&logits.get(0)?, 0)?
|
||||
.get(self.no_speech_token as usize)?
|
||||
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
no_speech_prob = softmax(&logits, 0)?
|
||||
.i(self.no_speech_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
}
|
||||
|
||||
let (seq_len, _) = logits.dims2()?;
|
||||
let logits = logits
|
||||
.get(seq_len - 1)?
|
||||
.broadcast_add(&self.suppress_tokens)?;
|
||||
let (_, seq_len, _) = ys.dims3()?;
|
||||
let logits = model
|
||||
.decoder
|
||||
.final_linear(&ys.i((..1, seq_len - 1..))?)?
|
||||
.i(0)?
|
||||
.i(0)?;
|
||||
let logits = logits.broadcast_add(&self.suppress_tokens)?;
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = softmax(&(&logits / t)?, 0)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
@ -159,7 +162,7 @@ impl Decoder {
|
||||
};
|
||||
tokens.push(next_token);
|
||||
let prob = softmax(&logits, candle::D::Minus1)?
|
||||
.get(next_token as usize)?
|
||||
.i(next_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
|
||||
break;
|
||||
|
@ -364,11 +364,13 @@ pub struct TextDecoder {
|
||||
ln: LayerNorm,
|
||||
mask: Tensor,
|
||||
span: tracing::Span,
|
||||
span_final: tracing::Span,
|
||||
}
|
||||
|
||||
impl TextDecoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
|
||||
let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
|
||||
let n_state = cfg.d_model;
|
||||
let n_head = cfg.decoder_attention_heads;
|
||||
let n_ctx = cfg.max_target_positions;
|
||||
@ -384,7 +386,6 @@ impl TextDecoder {
|
||||
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
||||
.collect();
|
||||
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
|
||||
|
||||
Ok(Self {
|
||||
token_embedding,
|
||||
positional_embedding,
|
||||
@ -392,6 +393,7 @@ impl TextDecoder {
|
||||
ln,
|
||||
mask,
|
||||
span,
|
||||
span_final,
|
||||
})
|
||||
}
|
||||
|
||||
@ -405,12 +407,16 @@ impl TextDecoder {
|
||||
for block in self.blocks.iter_mut() {
|
||||
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
|
||||
}
|
||||
let x = self.ln.forward(&x)?;
|
||||
let w = self
|
||||
.token_embedding
|
||||
.embeddings()
|
||||
.broadcast_left(x_dims[0])?;
|
||||
let logits = x.matmul(&w.t()?)?;
|
||||
self.ln.forward(&x)
|
||||
}
|
||||
|
||||
pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let b_size = x.dim(0)?;
|
||||
let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
|
||||
let logits = {
|
||||
let _enter = self.span_final.enter();
|
||||
x.matmul(&w.t()?)?
|
||||
};
|
||||
Ok(logits)
|
||||
}
|
||||
}
|
||||
@ -432,16 +438,4 @@ impl Whisper {
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
mel: &Tensor,
|
||||
tokens: &Tensor,
|
||||
flush_kv_cache: bool,
|
||||
) -> Result<Tensor> {
|
||||
let enc = self.encoder.forward(mel, flush_kv_cache)?;
|
||||
let dec = self.decoder.forward(tokens, &enc, flush_kv_cache)?;
|
||||
Ok(dec)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user