diff --git a/candle-transformers/src/models/trocr.rs b/candle-transformers/src/models/trocr.rs
index 785b06ca..49cd89f8 100644
--- a/candle-transformers/src/models/trocr.rs
+++ b/candle-transformers/src/models/trocr.rs
@@ -138,6 +138,10 @@ impl TrOCRAttention {
         })
     }
 
+    fn reset_kv_cache(&mut self) {
+        self.kv_cache = None // Discard whatever is stored in `kv_cache` (its payload type is not visible in this patch).
+    }
+
     fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result {
         tensor
             .reshape((bsz, (), self.num_heads, self.head_dim))?
@@ -239,6 +243,10 @@ impl TrOCRDecoderLayer {
         })
     }
 
+    fn reset_kv_cache(&mut self) {
+        self.self_attn.reset_kv_cache(); // Only `self_attn` is reset here; NOTE(review): confirm no other field of the layer holds a cache.
+    }
+
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -307,6 +315,10 @@ impl TrOCRDecoder {
         })
     }
 
+    fn reset_kv_cache(&mut self) {
+        self.layers.iter_mut().for_each(|l| l.reset_kv_cache()) // Forward the reset to every entry of `self.layers`.
+    }
+
     pub fn forward(
         &mut self,
         xs: &Tensor,
@@ -393,6 +405,10 @@ impl TrOCRForCausalLM {
 
         Ok(xs)
     }
+
+    fn reset_kv_cache(&mut self) { // Private (no `pub`): callable only from within this module.
+        self.decoder.reset_kv_cache(); // Pure delegation to `self.decoder`.
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -431,4 +447,8 @@ impl TrOCRModel {
         self.decoder
             .forward(xs, Some(encoder_xs), past_kv_len, &mask)
     }
+
+    pub fn reset_kv_cache(&mut self) { // Public entry point for clearing cached state; presumably called between independent inputs (TODO confirm with callers).
+        self.decoder.reset_kv_cache(); // Pure delegation to `self.decoder`; nothing else on the model is touched.
+    }
 }