Add the kv-cache to the whisper wasm version. (#689)

* Add the kv-cache to the whisper wasm version.

* Improve the handling of special tokens.
Laurent Mazare
2023-08-31 10:37:44 +02:00 (committed by GitHub)
parent db59816087
commit 94aa234dfd
3 changed files with 95 additions and 40 deletions
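Why this helps: in the decoder's cross-attention, the key and value tensors are projections of the encoder output, which stays fixed while a 30s segment is being decoded, so they can be computed once on the first step and reused for every following token. A minimal sketch of the get-or-compute pattern the diff below introduces (CrossAttention and cached_kv are illustrative names; it assumes candle's Tensor/Module and candle_nn's Linear rather than the example's own Linear type):

use candle::{Module, Result, Tensor};
use candle_nn::Linear;

// Illustrative reduction of the MultiHeadAttention change below: only the
// pieces involved in caching are shown.
struct CrossAttention {
    key: Linear,
    value: Linear,
    // Cached (key, value) projections of the encoder output.
    kv_cache: Option<(Tensor, Tensor)>,
}

impl CrossAttention {
    // Return the cached k/v if present, otherwise compute and cache them.
    // `flush` drops the cache first, e.g. at the start of a new segment.
    fn cached_kv(&mut self, xa: &Tensor, flush: bool) -> Result<(Tensor, Tensor)> {
        if flush {
            self.kv_cache = None;
        }
        if let Some((k, v)) = &self.kv_cache {
            // Tensor clones are shallow: they share the underlying storage.
            return Ok((k.clone(), v.clone()));
        }
        let k = self.key.forward(xa)?;
        let v = self.value.forward(xa)?;
        self.kv_cache = Some((k.clone(), v.clone()));
        Ok((k, v))
    }
}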


@@ -27,7 +27,7 @@ impl Decoder {
     }
 
     #[wasm_bindgen]
-    pub fn decode(&self, wav_input: Vec<u8>) -> Result<String, JsError> {
+    pub fn decode(&mut self, wav_input: Vec<u8>) -> Result<String, JsError> {
         let segments = self
             .decoder
             .convert_and_run(&wav_input)
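Because the cache makes inference stateful, decode now needs &mut self; the same &self → &mut self change ripples through convert_and_run, run, decode_with_fallback, and the worker below.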


@@ -109,6 +109,7 @@ struct MultiHeadAttention {
     value: Linear,
     out: Linear,
     n_head: usize,
+    kv_cache: Option<(Tensor, Tensor)>,
 }
 
 impl MultiHeadAttention {
@@ -123,14 +124,39 @@ impl MultiHeadAttention {
             value,
             out,
             n_head,
+            kv_cache: None,
         })
     }
 
-    fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+    fn forward(
+        &mut self,
+        x: &Tensor,
+        xa: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        flush_cache: bool,
+    ) -> Result<Tensor> {
         let _timer = crate::Timer::new("MultiHeadAttention::forward");
         let q = self.query.forward(x)?;
-        let k = self.key.forward(xa.unwrap_or(x))?;
-        let v = self.value.forward(xa.unwrap_or(x))?;
+        let (k, v) = match xa {
+            None => {
+                let k = self.key.forward(x)?;
+                let v = self.value.forward(x)?;
+                (k, v)
+            }
+            Some(x) => {
+                if flush_cache {
+                    self.kv_cache = None;
+                }
+                if let Some((k, v)) = &self.kv_cache {
+                    (k.clone(), v.clone())
+                } else {
+                    let k = self.key.forward(x)?;
+                    let v = self.value.forward(x)?;
+                    self.kv_cache = Some((k.clone(), v.clone()));
+                    (k, v)
+                }
+            }
+        };
         let wv = self.qkv_attention(&q, &k, &v, mask)?;
         let out = self.out.forward(&wv)?;
         Ok(out)
@@ -151,18 +177,9 @@ impl MultiHeadAttention {
     ) -> Result<Tensor> {
         let (_, n_ctx, n_state) = q.dims3()?;
         let scale = ((n_state / self.n_head) as f64).powf(-0.25);
-        let q = {
-            let _timer = crate::Timer::new("q::reshape");
-            (self.reshape_head(q)? * scale)?
-        };
-        let k = {
-            let _timer = crate::Timer::new("k::reshape");
-            (self.reshape_head(k)?.transpose(2, 3)? * scale)?
-        };
-        let v = {
-            let _timer = crate::Timer::new("v::reshape-contiguous");
-            self.reshape_head(v)?.contiguous()?
-        };
+        let q = (self.reshape_head(q)? * scale)?;
+        let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
+        let v = self.reshape_head(v)?.contiguous()?;
         let mut qk = {
             let _timer = crate::Timer::new("qk::matmul");
             q.matmul(&k)?
@@ -218,12 +235,20 @@ impl ResidualAttentionBlock {
         })
     }
 
-    fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+    fn forward(
+        &mut self,
+        x: &Tensor,
+        xa: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        flush_kv_cache: bool,
+    ) -> Result<Tensor> {
         let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
-        let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
+        let attn = self
+            .attn
+            .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
         let mut x = (x + attn)?;
-        if let Some((attn, ln)) = &self.cross_attn {
-            x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
+        if let Some((attn, ln)) = &mut self.cross_attn {
+            x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
         }
         let mlp = self.mlp_linear2.forward(
             &self
@@ -294,7 +319,7 @@ impl AudioEncoder {
             ln_post,
         })
     }
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
         let _timer = crate::Timer::new("AudioEncoder::forward");
         let x = {
             let _timer = crate::Timer::new("conv1::forward");
@@ -308,8 +333,8 @@ impl AudioEncoder {
         let (_bsize, seq_len, _hidden) = x.dims3()?;
         let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
         let mut x = x.broadcast_add(&positional_embedding)?;
-        for block in self.blocks.iter() {
-            x = block.forward(&x, None, None)?
+        for block in self.blocks.iter_mut() {
+            x = block.forward(&x, None, None, flush_kv_cache)?
         }
         let x = self.ln_post.forward(&x)?;
         Ok(x)
@@ -353,14 +378,14 @@ impl TextDecoder {
         })
     }
 
-    pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
+    pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
         let x_dims = x.dims();
         let last = x_dims[x_dims.len() - 1];
         let token_embedding = self.token_embedding.forward(x)?;
         let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
         let mut x = token_embedding.broadcast_add(&positional_embedding)?;
-        for block in self.blocks.iter() {
-            x = block.forward(&x, Some(xa), Some(&self.mask))?;
+        for block in self.blocks.iter_mut() {
+            x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
         }
         let x = self.ln.forward(&x)?;
         let w = self
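Note that only the cross-attention branch (xa = Some(..)) is cached: its keys and values depend only on the encoder output, which is constant for the segment being decoded. Self-attention still recomputes k/v from the growing token prefix on every step. The flush_kv_cache flag is threaded from AudioEncoder::forward and TextDecoder::forward through every ResidualAttentionBlock so that callers can invalidate the cache when a new segment starts.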


@@ -40,9 +40,13 @@ pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
 pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 
 // Tokenizer dependent bits.
-pub const SOT_TOKEN: u32 = 50257;
-pub const EOT_TOKEN: u32 = 50256;
-pub const NO_SPEECH_TOKEN: u32 = 50361;
+const SOT_TOKEN: &str = "<|startoftranscript|>";
+const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
+const TRANSLATE_TOKEN: &str = "<|translate|>";
+const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
+const EOT_TOKEN: &str = "<|endoftext|>";
+const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
 // From the _get_suppress_tokens function + 50362 (no timestamp)
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
 pub const SUPPRESS_TOKENS: [u32; 91] = [
@@ -71,11 +75,18 @@ pub struct Segment {
     pub dr: DecodingResult,
 }
 
+#[allow(unused)]
 pub struct Decoder {
     model: Whisper,
     mel_filters: Vec<f32>,
     tokenizer: Tokenizer,
     suppress_tokens: Tensor,
+    sot_token: u32,
+    transcribe_token: u32,
+    translate_token: u32,
+    eot_token: u32,
+    no_speech_token: u32,
+    no_timestamps_token: u32,
 }
 
 impl Decoder {
impl Decoder {
@@ -94,37 +105,49 @@ impl Decoder {
                 }
             })
             .collect();
+        let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
         let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
+        let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
+        let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
+        let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
+        let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
+        let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
         Ok(Self {
             model,
             mel_filters,
             tokenizer,
             suppress_tokens,
+            sot_token,
+            transcribe_token,
+            translate_token,
+            eot_token,
+            no_speech_token,
+            no_timestamps_token,
         })
     }
 
-    fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
-        let model = &self.model;
-        let audio_features = model.encoder.forward(mel)?;
+    fn decode(&mut self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
+        let model = &mut self.model;
+        let audio_features = model.encoder.forward(mel, true)?;
         console_log!("audio features: {:?}", audio_features.dims());
         let sample_len = model.config.max_target_positions / 2;
         let mut sum_logprob = 0f64;
         let mut no_speech_prob = f64::NAN;
-        let mut tokens = vec![SOT_TOKEN];
+        let mut tokens = vec![self.sot_token, self.transcribe_token];
         for i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
             // The model expects a batch dim but this inference loop does not handle
             // it so we add it at this point.
             let tokens_t = tokens_t.unsqueeze(0)?;
-            let logits = model.decoder.forward(&tokens_t, &audio_features)?;
+            let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
             let logits = logits.squeeze(0)?;
             // Extract the no speech probability on the first iteration by looking at the first
             // token logits and the probability for the according token.
             if i == 0 {
                 no_speech_prob = softmax(&logits.get(0)?, 0)?
-                    .get(NO_SPEECH_TOKEN as usize)?
+                    .get(self.no_speech_token as usize)?
                     .to_scalar::<f32>()? as f64;
             }
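Passing i == 0 as the flush flag drops the cached tensors on the first step of each sampling run and reuses them for the remaining steps, so a fresh segment (or a retry at another temperature) never sees stale activations.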
@@ -150,7 +173,7 @@ impl Decoder {
             let prob = softmax(&logits, candle::D::Minus1)?
                 .get(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
-            if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
+            if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
                 break;
             }
             sum_logprob += prob.ln();
@@ -169,7 +192,7 @@ impl Decoder {
     }
 
     fn decode_with_fallback(
-        &self,
+        &mut self,
         segment: &Tensor,
         rng: &mut StdRng,
     ) -> anyhow::Result<DecodingResult> {
@@ -195,7 +218,7 @@ impl Decoder {
         unreachable!()
     }
 
-    fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
+    fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
         let mut rng = StdRng::seed_from_u64(299792458);
         let (_, _, content_frames) = mel.dims3()?;
         let mut seek = 0;
@@ -239,7 +262,7 @@ impl Decoder {
         Ok(decoder)
     }
 
-    pub fn convert_and_run(&self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
+    pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
         let device = Device::Cpu;
         let mut wav_input = std::io::Cursor::new(wav_input);
         let (header, data) = wav::read(&mut wav_input)?;
@@ -262,6 +285,13 @@ impl Decoder {
     }
 }
 
+pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
+    match tokenizer.token_to_id(token) {
+        None => candle::bail!("no token-id for {token}"),
+        Some(id) => Ok(id),
+    }
+}
+
 // Communication to the worker happens through bincode, the model weights and configs are fetched
 // on the main thread and transfered via the following structure.
 #[derive(Serialize, Deserialize)]
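Resolving the ids through token_id at load time replaces the hard-coded constants removed above; those values are only correct for one specific Whisper tokenizer, while a lookup by the token's string form works across vocabularies.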
@@ -314,7 +344,7 @@ impl yew_agent::Worker for Worker {
                 }
                 Err(err) => Err(format!("model creation error {err:?}")),
             },
-            WorkerInput::DecodeTask { wav_bytes } => match &self.decoder {
+            WorkerInput::DecodeTask { wav_bytes } => match &mut self.decoder {
                 None => Err("model has not been set".to_string()),
                 Some(decoder) => decoder
                     .convert_and_run(&wav_bytes)