mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
feat: adds reset_kv_cache (#1335)
This commit is contained in:
@ -138,6 +138,10 @@ impl TrOCRAttention {
|
||||
})
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
|
||||
fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
|
||||
tensor
|
||||
.reshape((bsz, (), self.num_heads, self.head_dim))?
|
||||
@ -239,6 +243,10 @@ impl TrOCRDecoderLayer {
|
||||
})
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.self_attn.reset_kv_cache();
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
@ -307,6 +315,10 @@ impl TrOCRDecoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
@ -393,6 +405,10 @@ impl TrOCRForCausalLM {
|
||||
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.decoder.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -431,4 +447,8 @@ impl TrOCRModel {
|
||||
self.decoder
|
||||
.forward(xs, Some(encoder_xs), past_kv_len, &mask)
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
self.decoder.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user