Optimize the logit computations in the whisper example. (#434)
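This change stops the decoder from projecting every position of its output onto the vocabulary at each step. `TextDecoder::forward` now returns the hidden states, and a new `final_linear` method applies the vocabulary projection only where logits are actually consumed: the last position on every step, plus the first token once for the no-speech probability.

A back-of-the-envelope view of the saving in the final matmul; the dimensions below are illustrative assumptions (roughly whisper-tiny sized), not values taken from this commit:

```rust
// Rough cost model for the vocabulary projection, before vs. after.
// All three dimensions are assumptions for illustration only.
fn main() {
    let n_state: u64 = 384; // decoder hidden size (cfg.d_model)
    let vocab: u64 = 51_865; // vocabulary size
    let seq_len: u64 = 200; // tokens decoded so far in the loop

    // Before: logits computed for every position, every step.
    let before = seq_len * n_state * vocab;
    // After: logits for the last position only.
    let after = n_state * vocab;

    println!("multiply-adds before: {before}");
    println!("multiply-adds after:  {after}");
    println!("per-step saving: {}x", before / after); // == seq_len
}
```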
@@ -10,7 +10,7 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;

 use anyhow::{Error as E, Result};
-use candle::{DType, Device, Tensor};
+use candle::{DType, Device, IndexOp, Tensor};
 use candle_nn::{ops::softmax, VarBuilder};
 use clap::{Parser, ValueEnum};
 use hf_hub::{api::sync::Api, Repo, RepoType};
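The new `IndexOp` import brings candle's `.i(...)` indexing trait into scope; the hunks below use it in place of the former `.get(...)` calls and for range/tuple slicing. A minimal sketch of the trait, assuming the same `candle` crate alias the example uses; the shapes are made up:

```rust
use candle::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    // A (2, 3, 4) tensor to slice into.
    let t = Tensor::arange(0f32, 24., &Device::Cpu)?.reshape((2, 3, 4))?;
    // Integer indexes squeeze their dimension.
    let row = t.i((0, 1))?; // shape (4,)
    // Ranges keep their dimension, which the decode loop below relies on.
    let tail = t.i((..1, 2..))?; // shape (1, 1, 4)
    println!("{:?} {:?}", row.dims(), tail.dims());
    Ok(())
}
```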
@@ -128,21 +128,24 @@ impl Decoder {
             // The model expects a batch dim but this inference loop does not handle
             // it so we add it at this point.
             let tokens_t = tokens_t.unsqueeze(0)?;
-            let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
-            let logits = logits.squeeze(0)?;
+            let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;

             // Extract the no speech probability on the first iteration by looking at the first
             // token logits and the probability for the according token.
             if i == 0 {
-                no_speech_prob = softmax(&logits.get(0)?, 0)?
-                    .get(self.no_speech_token as usize)?
+                let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+                no_speech_prob = softmax(&logits, 0)?
+                    .i(self.no_speech_token as usize)?
                     .to_scalar::<f32>()? as f64;
             }

-            let (seq_len, _) = logits.dims2()?;
-            let logits = logits
-                .get(seq_len - 1)?
-                .broadcast_add(&self.suppress_tokens)?;
+            let (_, seq_len, _) = ys.dims3()?;
+            let logits = model
+                .decoder
+                .final_linear(&ys.i((..1, seq_len - 1..))?)?
+                .i(0)?
+                .i(0)?;
+            let logits = logits.broadcast_add(&self.suppress_tokens)?;
             let next_token = if t > 0f64 {
                 let prs = softmax(&(&logits / t)?, 0)?;
                 let logits_v: Vec<f32> = prs.to_vec1()?;
@@ -159,7 +162,7 @@ impl Decoder {
             };
             tokens.push(next_token);
             let prob = softmax(&logits, candle::D::Minus1)?
-                .get(next_token as usize)?
+                .i(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
             if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
                 break;
@@ -364,11 +364,13 @@ pub struct TextDecoder {
     ln: LayerNorm,
     mask: Tensor,
     span: tracing::Span,
+    span_final: tracing::Span,
 }

 impl TextDecoder {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
+        let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
         let n_state = cfg.d_model;
         let n_head = cfg.decoder_attention_heads;
         let n_ctx = cfg.max_target_positions;
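The new `span_final` follows the same tracing pattern as the existing `span`: entering it while the final projection runs lets a subscriber attribute that time separately. A minimal sketch of the span-guard pattern (no subscriber is installed here, so the span is a no-op):

```rust
// Sketch of the span-guard timing pattern; requires the `tracing` crate.
fn main() {
    let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
    let _enter = span_final.enter(); // recording starts when the guard is created
    // ... the vocabulary-projection matmul would run here ...
} // the guard drops and the elapsed time is attributed to the span
```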
@@ -384,7 +386,6 @@ impl TextDecoder {
             .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
             .collect();
         let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
-
         Ok(Self {
             token_embedding,
             positional_embedding,
@@ -392,6 +393,7 @@ impl TextDecoder {
             ln,
             mask,
             span,
+            span_final,
         })
     }

@@ -405,12 +407,16 @@ impl TextDecoder {
         for block in self.blocks.iter_mut() {
             x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
         }
-        let x = self.ln.forward(&x)?;
-        let w = self
-            .token_embedding
-            .embeddings()
-            .broadcast_left(x_dims[0])?;
-        let logits = x.matmul(&w.t()?)?;
+        self.ln.forward(&x)
+    }
+
+    pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
+        let b_size = x.dim(0)?;
+        let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
+        let logits = {
+            let _enter = self.span_final.enter();
+            x.matmul(&w.t()?)?
+        };
         Ok(logits)
     }
 }
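`final_linear` reuses the tied token-embedding matrix as the output projection, so the logits are `hidden @ embedding^T`. A standalone sketch of that pattern; the dimensions are made up, and it assumes the same `candle` crate alias as above rather than this file's actual types:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (vocab, n_state) = (8usize, 4usize);
    // Stand-in for token_embedding.embeddings(): shape (vocab, n_state).
    let embedding = Tensor::randn(0f32, 1f32, (vocab, n_state), &dev)?;
    // Stand-in for the decoder hidden states: (batch=1, seq=1, n_state).
    let hidden = Tensor::randn(0f32, 1f32, (1, 1, n_state), &dev)?;
    // broadcast_left adds the batch dim; t() transposes the last two dims,
    // so the matmul computes hidden @ embedding^T -> (1, 1, vocab) logits.
    let w = embedding.broadcast_left(1)?;
    let logits = hidden.matmul(&w.t()?)?;
    println!("logits dims: {:?}", logits.dims());
    Ok(())
}
```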
@@ -432,16 +438,4 @@ impl Whisper {
             config,
         })
     }
-
-    #[allow(dead_code)]
-    pub fn forward(
-        &mut self,
-        mel: &Tensor,
-        tokens: &Tensor,
-        flush_kv_cache: bool,
-    ) -> Result<Tensor> {
-        let enc = self.encoder.forward(mel, flush_kv_cache)?;
-        let dec = self.decoder.forward(tokens, &enc, flush_kv_cache)?;
-        Ok(dec)
-    }
 }