Expose a function to clear the KV cache on mixformers. (#964)

This commit is contained in:
Laurent Mazare
2023-09-26 05:41:07 +01:00
committed by GitHub
parent a084f65f9a
commit 1fcac4afed
2 changed files with 24 additions and 0 deletions

View File

@ -287,6 +287,10 @@ impl MHA {
.flatten_from(D::Minus2)?;
attn_output.apply(&self.out_proj)
}
/// Discards any cached key/value state by resetting `kv_cache` to `None`.
fn clear_kv_cache(&mut self) {
    // `take` leaves `None` behind; the previous contents are dropped right here.
    drop(self.kv_cache.take());
}
}
#[derive(Debug)]
@ -318,6 +322,10 @@ impl ParallelBlock {
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
attn_outputs + feed_forward_hidden_states + residual
}
/// Clears the KV cache of this block by delegating to its `mixer`.
fn clear_kv_cache(&mut self) {
    let mixer = &mut self.mixer;
    mixer.clear_kv_cache();
}
}
#[derive(Debug)]
@ -360,4 +368,8 @@ impl MixFormerSequentialForCausalLM {
}
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
}
/// Clears the KV cache of every block held in `self.blocks`.
///
/// NOTE(review): presumably intended to be called before processing an
/// unrelated sequence — confirm against callers.
pub fn clear_kv_cache(&mut self) {
    for block in self.blocks.iter_mut() {
        block.clear_kv_cache();
    }
}
}

View File

@ -268,6 +268,10 @@ impl MHA {
.flatten_from(D::Minus2)?;
attn_output.apply(&self.out_proj)
}
/// Discards any cached key/value state by resetting `kv_cache` to `None`.
fn clear_kv_cache(&mut self) {
    // `take` leaves `None` behind; the previous contents are dropped right here.
    drop(self.kv_cache.take());
}
}
#[derive(Debug)]
@ -299,6 +303,10 @@ impl ParallelBlock {
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
attn_outputs + feed_forward_hidden_states + residual
}
/// Clears the KV cache of this block by delegating to its `mixer`.
fn clear_kv_cache(&mut self) {
    let mixer = &mut self.mixer;
    mixer.clear_kv_cache();
}
}
#[derive(Debug)]
@ -341,4 +349,8 @@ impl MixFormerSequentialForCausalLM {
}
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
}
/// Clears the KV cache of every block held in `self.blocks`.
///
/// NOTE(review): presumably intended to be called before processing an
/// unrelated sequence — confirm against callers.
pub fn clear_kv_cache(&mut self) {
    for block in self.blocks.iter_mut() {
        block.clear_kv_cache();
    }
}
}