Add a KV cache to whisper. (#426)

This commit is contained in:
Laurent Mazare
2023-08-12 22:17:08 +02:00
committed by GitHub
parent a0908d212c
commit 60cd1551ca
3 changed files with 63 additions and 24 deletions

View File

@ -105,6 +105,7 @@ struct MultiHeadAttention {
out: Linear,
n_head: usize,
span: tracing::Span,
kv_cache: Option<(Tensor, Tensor)>,
}
impl MultiHeadAttention {
@ -121,14 +122,39 @@ impl MultiHeadAttention {
out,
n_head,
span,
kv_cache: None,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
fn forward(
&mut self,
x: &Tensor,
xa: Option<&Tensor>,
mask: Option<&Tensor>,
flush_cache: bool,
) -> Result<Tensor> {
let _enter = self.span.enter();
let q = self.query.forward(x)?;
let k = self.key.forward(xa.unwrap_or(x))?;
let v = self.value.forward(xa.unwrap_or(x))?;
let (k, v) = match xa {
None => {
let k = self.key.forward(x)?;
let v = self.value.forward(x)?;
(k, v)
}
Some(x) => {
if flush_cache {
self.kv_cache = None;
}
if let Some((k, v)) = &self.kv_cache {
(k.clone(), v.clone())
} else {
let k = self.key.forward(x)?;
let v = self.value.forward(x)?;
self.kv_cache = Some((k.clone(), v.clone()));
(k, v)
}
}
};
let wv = self.qkv_attention(&q, &k, &v, mask)?;
let out = self.out.forward(&wv)?;
Ok(out)
@ -201,12 +227,20 @@ impl ResidualAttentionBlock {
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
fn forward(
&mut self,
x: &Tensor,
xa: Option<&Tensor>,
mask: Option<&Tensor>,
flush_kv_cache: bool,
) -> Result<Tensor> {
let _enter = self.span.enter();
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
let attn = self
.attn
.forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
let mut x = (x + attn)?;
if let Some((attn, ln)) = &self.cross_attn {
x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
if let Some((attn, ln)) = &mut self.cross_attn {
x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
}
let mlp = self.mlp_linear2.forward(
&self
@ -283,7 +317,7 @@ impl AudioEncoder {
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
let _enter = self.span.enter();
let x = {
let _enter = self.conv1_span.enter();
@ -297,8 +331,8 @@ impl AudioEncoder {
let (_bsize, seq_len, _hidden) = x.dims3()?;
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
let mut x = x.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() {
x = block.forward(&x, None, None)?
for block in self.blocks.iter_mut() {
x = block.forward(&x, None, None, flush_kv_cache)?
}
let x = self.ln_post.forward(&x)?;
Ok(x)
@ -344,15 +378,15 @@ impl TextDecoder {
})
}
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
let _enter = self.span.enter();
let x_dims = x.dims();
let last = x_dims[x_dims.len() - 1];
let token_embedding = self.token_embedding.forward(x)?;
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() {
x = block.forward(&x, Some(xa), Some(&self.mask))?;
for block in self.blocks.iter_mut() {
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
}
let x = self.ln.forward(&x)?;
let w = self
@ -383,9 +417,14 @@ impl Whisper {
}
#[allow(dead_code)]
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
let enc = self.encoder.forward(mel)?;
let dec = self.decoder.forward(tokens, &enc)?;
pub fn forward(
&mut self,
mel: &Tensor,
tokens: &Tensor,
flush_kv_cache: bool,
) -> Result<Tensor> {
let enc = self.encoder.forward(mel, flush_kv_cache)?;
let dec = self.decoder.forward(tokens, &enc, flush_kv_cache)?;
Ok(dec)
}
}