mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Expose a function to clear the KV cache on mixformers. (#964)
This commit is contained in:
@ -287,6 +287,10 @@ impl MHA {
|
||||
.flatten_from(D::Minus2)?;
|
||||
attn_output.apply(&self.out_proj)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -318,6 +322,10 @@ impl ParallelBlock {
|
||||
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
|
||||
attn_outputs + feed_forward_hidden_states + residual
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.mixer.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -360,4 +368,8 @@ impl MixFormerSequentialForCausalLM {
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
|
||||
}
|
||||
}
|
||||
|
@ -268,6 +268,10 @@ impl MHA {
|
||||
.flatten_from(D::Minus2)?;
|
||||
attn_output.apply(&self.out_proj)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -299,6 +303,10 @@ impl ParallelBlock {
|
||||
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
|
||||
attn_outputs + feed_forward_hidden_states + residual
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.mixer.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -341,4 +349,8 @@ impl MixFormerSequentialForCausalLM {
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user