mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
feat: add clear_kv_cache to mistral and qmistral models (#1464)
This commit is contained in:
@ -297,6 +297,10 @@ impl Attention {
|
|||||||
.reshape((b_sz, q_len, self.hidden_size))?
|
.reshape((b_sz, q_len, self.hidden_size))?
|
||||||
.apply(&self.o_proj)
|
.apply(&self.o_proj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -340,6 +344,10 @@ impl DecoderLayer {
|
|||||||
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
||||||
residual + xs
|
residual + xs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -423,4 +431,10 @@ impl Model {
|
|||||||
.apply(&self.norm)?
|
.apply(&self.norm)?
|
||||||
.apply(&self.lm_head)
|
.apply(&self.lm_head)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
layer.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -198,6 +198,10 @@ impl Attention {
|
|||||||
.reshape((b_sz, q_len, self.hidden_size))?
|
.reshape((b_sz, q_len, self.hidden_size))?
|
||||||
.apply(&self.o_proj)
|
.apply(&self.o_proj)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -241,6 +245,10 @@ impl DecoderLayer {
|
|||||||
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
||||||
residual + xs
|
residual + xs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -322,4 +330,10 @@ impl Model {
|
|||||||
.apply(&self.norm)?
|
.apply(&self.norm)?
|
||||||
.apply(&self.lm_head)
|
.apply(&self.lm_head)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
layer.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user