feat: adds reset_kv_cache (#1335)

This commit is contained in:
drbh
2023-11-16 16:17:42 -05:00
committed by GitHub
parent 92a05b51cf
commit a1f41ab37b

View File

@ -138,6 +138,10 @@ impl TrOCRAttention {
})
}
fn reset_kv_cache(&mut self) {
self.kv_cache = None
}
fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
tensor
.reshape((bsz, (), self.num_heads, self.head_dim))?
@ -239,6 +243,10 @@ impl TrOCRDecoderLayer {
})
}
fn reset_kv_cache(&mut self) {
self.self_attn.reset_kv_cache();
}
fn forward(
&mut self,
xs: &Tensor,
@ -307,6 +315,10 @@ impl TrOCRDecoder {
})
}
fn reset_kv_cache(&mut self) {
self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
}
pub fn forward(
&mut self,
xs: &Tensor,
@ -393,6 +405,10 @@ impl TrOCRForCausalLM {
Ok(xs)
}
fn reset_kv_cache(&mut self) {
self.decoder.reset_kv_cache();
}
}
#[derive(Debug, Clone)]
@ -431,4 +447,8 @@ impl TrOCRModel {
self.decoder
.forward(xs, Some(encoder_xs), past_kv_len, &mask)
}
pub fn reset_kv_cache(&mut self) {
self.decoder.reset_kv_cache();
}
}