From 94aa234dfd85fece132d2e61999f409dd6321c5a Mon Sep 17 00:00:00 2001
From: Laurent Mazare <laurent.mazare@gmail.com>
Date: Thu, 31 Aug 2023 10:37:44 +0200
Subject: [PATCH] Add the kv-cache to the whisper wasm version. (#689)

* Add the kv-cache to the whisper wasm version.

* Improve the handling of special tokens.
---
 candle-wasm-examples/whisper/src/bin/m.rs  |  2 +-
 candle-wasm-examples/whisper/src/model.rs  | 75 ++++++++++++++--------
 candle-wasm-examples/whisper/src/worker.rs | 58 +++++++++++++----
 3 files changed, 95 insertions(+), 40 deletions(-)

diff --git a/candle-wasm-examples/whisper/src/bin/m.rs b/candle-wasm-examples/whisper/src/bin/m.rs
index 88b25267..0716a20d 100644
--- a/candle-wasm-examples/whisper/src/bin/m.rs
+++ b/candle-wasm-examples/whisper/src/bin/m.rs
@@ -27,7 +27,7 @@ impl Decoder {
     }
 
     #[wasm_bindgen]
-    pub fn decode(&self, wav_input: Vec<u8>) -> Result<String, JsError> {
+    pub fn decode(&mut self, wav_input: Vec<u8>) -> Result<String, JsError> {
         let segments = self
             .decoder
             .convert_and_run(&wav_input)
diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs
index 239ceee5..8574124b 100644
--- a/candle-wasm-examples/whisper/src/model.rs
+++ b/candle-wasm-examples/whisper/src/model.rs
@@ -109,6 +109,7 @@ struct MultiHeadAttention {
     value: Linear,
     out: Linear,
     n_head: usize,
+    kv_cache: Option<(Tensor, Tensor)>,
 }
 
 impl MultiHeadAttention {
@@ -123,14 +124,39 @@ impl MultiHeadAttention {
             value,
             out,
             n_head,
+            kv_cache: None,
         })
     }
 
-    fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+    fn forward(
+        &mut self,
+        x: &Tensor,
+        xa: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        flush_cache: bool,
+    ) -> Result<Tensor> {
         let _timer = crate::Timer::new("MultiHeadAttention::forward");
         let q = self.query.forward(x)?;
-        let k = self.key.forward(xa.unwrap_or(x))?;
-        let v = self.value.forward(xa.unwrap_or(x))?;
+        let (k, v) = match xa {
+            None => {
+                let k = self.key.forward(x)?;
+                let v = self.value.forward(x)?;
+                (k, v)
+            }
+            Some(x) => {
+                if flush_cache {
+                    self.kv_cache = None;
+                }
+                if let Some((k, v)) = &self.kv_cache {
+                    (k.clone(), v.clone())
+                } else {
+                    let k = self.key.forward(x)?;
+                    let v = self.value.forward(x)?;
+                    self.kv_cache = Some((k.clone(), v.clone()));
+                    (k, v)
+                }
+            }
+        };
         let wv = self.qkv_attention(&q, &k, &v, mask)?;
         let out = self.out.forward(&wv)?;
         Ok(out)
@@ -151,18 +177,9 @@ impl MultiHeadAttention {
     ) -> Result<Tensor> {
         let (_, n_ctx, n_state) = q.dims3()?;
         let scale = ((n_state / self.n_head) as f64).powf(-0.25);
-        let q = {
-            let _timer = crate::Timer::new("q::reshape");
-            (self.reshape_head(q)? * scale)?
-        };
-        let k = {
-            let _timer = crate::Timer::new("k::reshape");
-            (self.reshape_head(k)?.transpose(2, 3)? * scale)?
-        };
-        let v = {
-            let _timer = crate::Timer::new("v::reshape-contiguous");
-            self.reshape_head(v)?.contiguous()?
-        };
+        let q = (self.reshape_head(q)? * scale)?;
+        let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
+        let v = self.reshape_head(v)?.contiguous()?;
         let mut qk = {
             let _timer = crate::Timer::new("qk::matmul");
             q.matmul(&k)?
@@ -218,12 +235,20 @@ impl ResidualAttentionBlock {
         })
     }
 
-    fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+    fn forward(
+        &mut self,
+        x: &Tensor,
+        xa: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        flush_kv_cache: bool,
+    ) -> Result<Tensor> {
         let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
-        let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
+        let attn = self
+            .attn
+            .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
         let mut x = (x + attn)?;
-        if let Some((attn, ln)) = &self.cross_attn {
-            x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
+        if let Some((attn, ln)) = &mut self.cross_attn {
+            x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
         }
         let mlp = self.mlp_linear2.forward(
             &self
@@ -294,7 +319,7 @@ impl AudioEncoder {
             ln_post,
         })
     }
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
         let _timer = crate::Timer::new("AudioEncoder::forward");
         let x = {
             let _timer = crate::Timer::new("conv1::forward");
@@ -308,8 +333,8 @@ impl AudioEncoder {
         let (_bsize, seq_len, _hidden) = x.dims3()?;
         let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
         let mut x = x.broadcast_add(&positional_embedding)?;
-        for block in self.blocks.iter() {
-            x = block.forward(&x, None, None)?
+        for block in self.blocks.iter_mut() {
+            x = block.forward(&x, None, None, flush_kv_cache)?
         }
         let x = self.ln_post.forward(&x)?;
         Ok(x)
@@ -353,14 +378,14 @@ impl TextDecoder {
         })
     }
 
-    pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
+    pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
         let x_dims = x.dims();
         let last = x_dims[x_dims.len() - 1];
         let token_embedding = self.token_embedding.forward(x)?;
         let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
         let mut x = token_embedding.broadcast_add(&positional_embedding)?;
-        for block in self.blocks.iter() {
-            x = block.forward(&x, Some(xa), Some(&self.mask))?;
+        for block in self.blocks.iter_mut() {
+            x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
         }
         let x = self.ln.forward(&x)?;
         let w = self
diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs
index 49b2cd09..4fb2223a 100644
--- a/candle-wasm-examples/whisper/src/worker.rs
+++ b/candle-wasm-examples/whisper/src/worker.rs
@@ -40,9 +40,13 @@ pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
 pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 
 // Tokenizer dependent bits.
-pub const SOT_TOKEN: u32 = 50257;
-pub const EOT_TOKEN: u32 = 50256;
-pub const NO_SPEECH_TOKEN: u32 = 50361;
+const SOT_TOKEN: &str = "<|startoftranscript|>";
+const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
+const TRANSLATE_TOKEN: &str = "<|translate|>";
+const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
+const EOT_TOKEN: &str = "<|endoftext|>";
+const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
+
 // From the _get_suppress_tokens function + 50362 (no timestamp)
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
 pub const SUPPRESS_TOKENS: [u32; 91] = [
@@ -71,11 +75,18 @@ pub struct Segment {
     pub dr: DecodingResult,
 }
 
+#[allow(unused)]
 pub struct Decoder {
     model: Whisper,
     mel_filters: Vec<f32>,
     tokenizer: Tokenizer,
     suppress_tokens: Tensor,
+    sot_token: u32,
+    transcribe_token: u32,
+    translate_token: u32,
+    eot_token: u32,
+    no_speech_token: u32,
+    no_timestamps_token: u32,
 }
 
 impl Decoder {
@@ -94,37 +105,49 @@ impl Decoder {
                 }
             })
             .collect();
+        let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
         let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
+        let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
+        let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
+        let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
+        let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
+        let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
         Ok(Self {
             model,
             mel_filters,
             tokenizer,
            suppress_tokens,
+            sot_token,
+            transcribe_token,
+            translate_token,
+            eot_token,
+            no_speech_token,
+            no_timestamps_token,
         })
     }
 
-    fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
-        let model = &self.model;
-        let audio_features = model.encoder.forward(mel)?;
+    fn decode(&mut self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
+        let model = &mut self.model;
+        let audio_features = model.encoder.forward(mel, true)?;
         console_log!("audio features: {:?}", audio_features.dims());
         let sample_len = model.config.max_target_positions / 2;
         let mut sum_logprob = 0f64;
         let mut no_speech_prob = f64::NAN;
-        let mut tokens = vec![SOT_TOKEN];
+        let mut tokens = vec![self.sot_token, self.transcribe_token];
         for i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
             // The model expects a batch dim but this inference loop does not handle
             // it so we add it at this point.
             let tokens_t = tokens_t.unsqueeze(0)?;
-            let logits = model.decoder.forward(&tokens_t, &audio_features)?;
+            let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
             let logits = logits.squeeze(0)?;
             // Extract the no speech probability on the first iteration by looking at the first
             // token logits and the probability for the according token.
             if i == 0 {
                 no_speech_prob = softmax(&logits.get(0)?, 0)?
-                    .get(NO_SPEECH_TOKEN as usize)?
+                    .get(self.no_speech_token as usize)?
                     .to_scalar::<f32>()? as f64;
             }
@@ -150,7 +173,7 @@ impl Decoder {
             let prob = softmax(&logits, candle::D::Minus1)?
                 .get(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
-            if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
+            if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
                 break;
             }
             sum_logprob += prob.ln();
@@ -169,7 +192,7 @@ impl Decoder {
     }
 
     fn decode_with_fallback(
-        &self,
+        &mut self,
         segment: &Tensor,
         rng: &mut StdRng,
     ) -> anyhow::Result<DecodingResult> {
@@ -195,7 +218,7 @@ impl Decoder {
         unreachable!()
     }
 
-    fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
+    fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
         let mut rng = StdRng::seed_from_u64(299792458);
         let (_, _, content_frames) = mel.dims3()?;
         let mut seek = 0;
@@ -239,7 +262,7 @@ impl Decoder {
         Ok(decoder)
     }
 
-    pub fn convert_and_run(&self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
+    pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
         let device = Device::Cpu;
         let mut wav_input = std::io::Cursor::new(wav_input);
         let (header, data) = wav::read(&mut wav_input)?;
@@ -262,6 +285,13 @@ impl Decoder {
     }
 }
 
+pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
+    match tokenizer.token_to_id(token) {
+        None => candle::bail!("no token-id for {token}"),
+        Some(id) => Ok(id),
+    }
+}
+
 // Communication to the worker happens through bincode, the model weights and configs are fetched
 // on the main thread and transfered via the following structure.
 #[derive(Serialize, Deserialize)]
@@ -314,7 +344,7 @@ impl yew_agent::Worker for Worker {
                 }
                 Err(err) => Err(format!("model creation error {err:?}")),
             },
-            WorkerInput::DecodeTask { wav_bytes } => match &self.decoder {
+            WorkerInput::DecodeTask { wav_bytes } => match &mut self.decoder {
                 None => Err("model has not been set".to_string()),
                 Some(decoder) => decoder
                     .convert_and_run(&wav_bytes)
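
A note on the caching pattern above: the cross-attention keys and values depend only on the encoder output, which is fixed for a whole audio segment, so they can be projected once and then reused on every subsequent decoding step. `flush_kv_cache` is passed as `true` only on the first step of a segment (`i == 0` in `Decoder::decode`) to drop the stale cache left over from the previous segment. Below is a minimal, self-contained sketch of the same pattern; `Proj` and the `Vec<f32>` buffers are hypothetical stand-ins for candle's `Linear` and `Tensor`, not the real API.

    // Stand-in for a linear projection layer (candle's `Linear` in the patch).
    struct Proj;
    impl Proj {
        fn forward(&self, x: &[f32]) -> Vec<f32> {
            x.to_vec() // identity, for illustration only
        }
    }

    struct CrossAttention {
        key: Proj,
        value: Proj,
        // Cached (k, v) projections of the encoder output.
        kv_cache: Option<(Vec<f32>, Vec<f32>)>,
    }

    impl CrossAttention {
        // `flush_cache` mirrors the patch: true on the first decoding step of a
        // segment, so the cache of the previous segment is discarded.
        fn keys_values(&mut self, encoder_out: &[f32], flush_cache: bool) -> (Vec<f32>, Vec<f32>) {
            if flush_cache {
                self.kv_cache = None;
            }
            if let Some((k, v)) = &self.kv_cache {
                // Cache hit: later steps reuse the same projections.
                (k.clone(), v.clone())
            } else {
                // Cache miss: project once and remember the result.
                let k = self.key.forward(encoder_out);
                let v = self.value.forward(encoder_out);
                self.kv_cache = Some((k.clone(), v.clone()));
                (k, v)
            }
        }
    }

    fn main() {
        let mut attn = CrossAttention { key: Proj, value: Proj, kv_cache: None };
        let encoder_out = vec![0.1_f32, 0.2, 0.3];
        for i in 0..3 {
            // Mirrors `model.decoder.forward(&tokens_t, &audio_features, i == 0)`:
            // only the first iteration recomputes k/v, the rest hit the cache.
            let (_k, _v) = attn.keys_values(&encoder_out, i == 0);
        }
    }

Only the cross-attention path is cached; self-attention (the `None` arm of the match in `MultiHeadAttention::forward`) still recomputes its keys and values on every step, since those grow with the token sequence.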
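On the special-token change: the removed constants (50257, 50256, 50361) are tokenizer dependent, as the comment in worker.rs notes, so the patch resolves the ids by name through the new `token_id` helper when the `Decoder` is built. A sketch of the same lookup against the `tokenizers` crate follows; it assumes the `tokenizers` and `anyhow` crates, returns `anyhow::Result` instead of `candle::Result` so it stands alone, and uses a placeholder `tokenizer.json` path (the wasm worker actually receives the tokenizer bytes in a message rather than reading a file).

    use tokenizers::Tokenizer;

    const SOT_TOKEN: &str = "<|startoftranscript|>";
    const EOT_TOKEN: &str = "<|endoftext|>";

    // Same shape as the patch's `token_id`, with anyhow for portability.
    fn token_id(tokenizer: &Tokenizer, token: &str) -> anyhow::Result<u32> {
        tokenizer
            .token_to_id(token)
            .ok_or_else(|| anyhow::anyhow!("no token-id for {token}"))
    }

    fn main() -> anyhow::Result<()> {
        // Placeholder path, for illustration only.
        let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(anyhow::Error::msg)?;
        let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
        let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
        println!("sot={sot_token} eot={eot_token}");
        Ok(())
    }

Looking the ids up once at construction time keeps the hot decoding loop free of string lookups and fails fast, at model-load time, if the tokenizer does not define one of the expected tokens.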