mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Expose a function to clear the KV cache on mixformers. (#964)
This commit is contained in:
@ -287,6 +287,10 @@ impl MHA {
|
|||||||
.flatten_from(D::Minus2)?;
|
.flatten_from(D::Minus2)?;
|
||||||
attn_output.apply(&self.out_proj)
|
attn_output.apply(&self.out_proj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -318,6 +322,10 @@ impl ParallelBlock {
|
|||||||
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
|
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
|
||||||
attn_outputs + feed_forward_hidden_states + residual
|
attn_outputs + feed_forward_hidden_states + residual
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.mixer.clear_kv_cache()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -360,4 +368,8 @@ impl MixFormerSequentialForCausalLM {
|
|||||||
}
|
}
|
||||||
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
|
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -268,6 +268,10 @@ impl MHA {
|
|||||||
.flatten_from(D::Minus2)?;
|
.flatten_from(D::Minus2)?;
|
||||||
attn_output.apply(&self.out_proj)
|
attn_output.apply(&self.out_proj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -299,6 +303,10 @@ impl ParallelBlock {
|
|||||||
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
|
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
|
||||||
attn_outputs + feed_forward_hidden_states + residual
|
attn_outputs + feed_forward_hidden_states + residual
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.mixer.clear_kv_cache()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -341,4 +349,8 @@ impl MixFormerSequentialForCausalLM {
|
|||||||
}
|
}
|
||||||
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
|
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user