mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add the kv-cache to the whisper wasm version. (#689)
* Add the kv-cache to the whisper wasm version. * Improve the handling of special tokens.
This commit is contained in:
@ -27,7 +27,7 @@ impl Decoder {
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn decode(&self, wav_input: Vec<u8>) -> Result<String, JsError> {
|
||||
pub fn decode(&mut self, wav_input: Vec<u8>) -> Result<String, JsError> {
|
||||
let segments = self
|
||||
.decoder
|
||||
.convert_and_run(&wav_input)
|
||||
|
@ -109,6 +109,7 @@ struct MultiHeadAttention {
|
||||
value: Linear,
|
||||
out: Linear,
|
||||
n_head: usize,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl MultiHeadAttention {
|
||||
@ -123,14 +124,39 @@ impl MultiHeadAttention {
|
||||
value,
|
||||
out,
|
||||
n_head,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
flush_cache: bool,
|
||||
) -> Result<Tensor> {
|
||||
let _timer = crate::Timer::new("MultiHeadAttention::forward");
|
||||
let q = self.query.forward(x)?;
|
||||
let k = self.key.forward(xa.unwrap_or(x))?;
|
||||
let v = self.value.forward(xa.unwrap_or(x))?;
|
||||
let (k, v) = match xa {
|
||||
None => {
|
||||
let k = self.key.forward(x)?;
|
||||
let v = self.value.forward(x)?;
|
||||
(k, v)
|
||||
}
|
||||
Some(x) => {
|
||||
if flush_cache {
|
||||
self.kv_cache = None;
|
||||
}
|
||||
if let Some((k, v)) = &self.kv_cache {
|
||||
(k.clone(), v.clone())
|
||||
} else {
|
||||
let k = self.key.forward(x)?;
|
||||
let v = self.value.forward(x)?;
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
};
|
||||
let wv = self.qkv_attention(&q, &k, &v, mask)?;
|
||||
let out = self.out.forward(&wv)?;
|
||||
Ok(out)
|
||||
@ -151,18 +177,9 @@ impl MultiHeadAttention {
|
||||
) -> Result<Tensor> {
|
||||
let (_, n_ctx, n_state) = q.dims3()?;
|
||||
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
||||
let q = {
|
||||
let _timer = crate::Timer::new("q::reshape");
|
||||
(self.reshape_head(q)? * scale)?
|
||||
};
|
||||
let k = {
|
||||
let _timer = crate::Timer::new("k::reshape");
|
||||
(self.reshape_head(k)?.transpose(2, 3)? * scale)?
|
||||
};
|
||||
let v = {
|
||||
let _timer = crate::Timer::new("v::reshape-contiguous");
|
||||
self.reshape_head(v)?.contiguous()?
|
||||
};
|
||||
let q = (self.reshape_head(q)? * scale)?;
|
||||
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
|
||||
let v = self.reshape_head(v)?.contiguous()?;
|
||||
let mut qk = {
|
||||
let _timer = crate::Timer::new("qk::matmul");
|
||||
q.matmul(&k)?
|
||||
@ -218,12 +235,20 @@ impl ResidualAttentionBlock {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
flush_kv_cache: bool,
|
||||
) -> Result<Tensor> {
|
||||
let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
|
||||
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
|
||||
let attn = self
|
||||
.attn
|
||||
.forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
|
||||
let mut x = (x + attn)?;
|
||||
if let Some((attn, ln)) = &self.cross_attn {
|
||||
x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
|
||||
if let Some((attn, ln)) = &mut self.cross_attn {
|
||||
x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
|
||||
}
|
||||
let mlp = self.mlp_linear2.forward(
|
||||
&self
|
||||
@ -294,7 +319,7 @@ impl AudioEncoder {
|
||||
ln_post,
|
||||
})
|
||||
}
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||
let _timer = crate::Timer::new("AudioEncoder::forward");
|
||||
let x = {
|
||||
let _timer = crate::Timer::new("conv1::forward");
|
||||
@ -308,8 +333,8 @@ impl AudioEncoder {
|
||||
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
||||
let mut x = x.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter() {
|
||||
x = block.forward(&x, None, None)?
|
||||
for block in self.blocks.iter_mut() {
|
||||
x = block.forward(&x, None, None, flush_kv_cache)?
|
||||
}
|
||||
let x = self.ln_post.forward(&x)?;
|
||||
Ok(x)
|
||||
@ -353,14 +378,14 @@ impl TextDecoder {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||
let x_dims = x.dims();
|
||||
let last = x_dims[x_dims.len() - 1];
|
||||
let token_embedding = self.token_embedding.forward(x)?;
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
||||
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter() {
|
||||
x = block.forward(&x, Some(xa), Some(&self.mask))?;
|
||||
for block in self.blocks.iter_mut() {
|
||||
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
|
||||
}
|
||||
let x = self.ln.forward(&x)?;
|
||||
let w = self
|
||||
|
@ -40,9 +40,13 @@ pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
|
||||
pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
||||
|
||||
// Tokenizer dependent bits.
|
||||
pub const SOT_TOKEN: u32 = 50257;
|
||||
pub const EOT_TOKEN: u32 = 50256;
|
||||
pub const NO_SPEECH_TOKEN: u32 = 50361;
|
||||
const SOT_TOKEN: &str = "<|startoftranscript|>";
|
||||
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
|
||||
const TRANSLATE_TOKEN: &str = "<|translate|>";
|
||||
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
||||
const EOT_TOKEN: &str = "<|endoftext|>";
|
||||
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
||||
|
||||
// From the _get_suppress_tokens function + 50362 (no timestamp)
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
|
||||
pub const SUPPRESS_TOKENS: [u32; 91] = [
|
||||
@ -71,11 +75,18 @@ pub struct Segment {
|
||||
pub dr: DecodingResult,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub struct Decoder {
|
||||
model: Whisper,
|
||||
mel_filters: Vec<f32>,
|
||||
tokenizer: Tokenizer,
|
||||
suppress_tokens: Tensor,
|
||||
sot_token: u32,
|
||||
transcribe_token: u32,
|
||||
translate_token: u32,
|
||||
eot_token: u32,
|
||||
no_speech_token: u32,
|
||||
no_timestamps_token: u32,
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
@ -94,37 +105,49 @@ impl Decoder {
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
|
||||
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
||||
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
|
||||
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
|
||||
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
|
||||
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
|
||||
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
|
||||
Ok(Self {
|
||||
model,
|
||||
mel_filters,
|
||||
tokenizer,
|
||||
suppress_tokens,
|
||||
sot_token,
|
||||
transcribe_token,
|
||||
translate_token,
|
||||
eot_token,
|
||||
no_speech_token,
|
||||
no_timestamps_token,
|
||||
})
|
||||
}
|
||||
|
||||
fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
|
||||
let model = &self.model;
|
||||
let audio_features = model.encoder.forward(mel)?;
|
||||
fn decode(&mut self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
|
||||
let model = &mut self.model;
|
||||
let audio_features = model.encoder.forward(mel, true)?;
|
||||
console_log!("audio features: {:?}", audio_features.dims());
|
||||
let sample_len = model.config.max_target_positions / 2;
|
||||
let mut sum_logprob = 0f64;
|
||||
let mut no_speech_prob = f64::NAN;
|
||||
let mut tokens = vec![SOT_TOKEN];
|
||||
let mut tokens = vec![self.sot_token, self.transcribe_token];
|
||||
for i in 0..sample_len {
|
||||
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
|
||||
|
||||
// The model expects a batch dim but this inference loop does not handle
|
||||
// it so we add it at this point.
|
||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
|
||||
let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
// token logits and the probability for the according token.
|
||||
if i == 0 {
|
||||
no_speech_prob = softmax(&logits.get(0)?, 0)?
|
||||
.get(NO_SPEECH_TOKEN as usize)?
|
||||
.get(self.no_speech_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
}
|
||||
|
||||
@ -150,7 +173,7 @@ impl Decoder {
|
||||
let prob = softmax(&logits, candle::D::Minus1)?
|
||||
.get(next_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
|
||||
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
|
||||
break;
|
||||
}
|
||||
sum_logprob += prob.ln();
|
||||
@ -169,7 +192,7 @@ impl Decoder {
|
||||
}
|
||||
|
||||
fn decode_with_fallback(
|
||||
&self,
|
||||
&mut self,
|
||||
segment: &Tensor,
|
||||
rng: &mut StdRng,
|
||||
) -> anyhow::Result<DecodingResult> {
|
||||
@ -195,7 +218,7 @@ impl Decoder {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
|
||||
fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
|
||||
let mut rng = StdRng::seed_from_u64(299792458);
|
||||
let (_, _, content_frames) = mel.dims3()?;
|
||||
let mut seek = 0;
|
||||
@ -239,7 +262,7 @@ impl Decoder {
|
||||
Ok(decoder)
|
||||
}
|
||||
|
||||
pub fn convert_and_run(&self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
|
||||
pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
|
||||
let device = Device::Cpu;
|
||||
let mut wav_input = std::io::Cursor::new(wav_input);
|
||||
let (header, data) = wav::read(&mut wav_input)?;
|
||||
@ -262,6 +285,13 @@ impl Decoder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
|
||||
match tokenizer.token_to_id(token) {
|
||||
None => candle::bail!("no token-id for {token}"),
|
||||
Some(id) => Ok(id),
|
||||
}
|
||||
}
|
||||
|
||||
// Communication to the worker happens through bincode, the model weights and configs are fetched
|
||||
// on the main thread and transfered via the following structure.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@ -314,7 +344,7 @@ impl yew_agent::Worker for Worker {
|
||||
}
|
||||
Err(err) => Err(format!("model creation error {err:?}")),
|
||||
},
|
||||
WorkerInput::DecodeTask { wav_bytes } => match &self.decoder {
|
||||
WorkerInput::DecodeTask { wav_bytes } => match &mut self.decoder {
|
||||
None => Err("model has not been set".to_string()),
|
||||
Some(decoder) => decoder
|
||||
.convert_and_run(&wav_bytes)
|
||||
|
Reference in New Issue
Block a user