mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
feat: adds reset_kv_cache (#1335)
This commit is contained in:
@ -138,6 +138,10 @@ impl TrOCRAttention {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reset_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
|
|
||||||
fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
|
fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
|
||||||
tensor
|
tensor
|
||||||
.reshape((bsz, (), self.num_heads, self.head_dim))?
|
.reshape((bsz, (), self.num_heads, self.head_dim))?
|
||||||
@ -239,6 +243,10 @@ impl TrOCRDecoderLayer {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reset_kv_cache(&mut self) {
|
||||||
|
self.self_attn.reset_kv_cache();
|
||||||
|
}
|
||||||
|
|
||||||
fn forward(
|
fn forward(
|
||||||
&mut self,
|
&mut self,
|
||||||
xs: &Tensor,
|
xs: &Tensor,
|
||||||
@ -307,6 +315,10 @@ impl TrOCRDecoder {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reset_kv_cache(&mut self) {
|
||||||
|
self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn forward(
|
pub fn forward(
|
||||||
&mut self,
|
&mut self,
|
||||||
xs: &Tensor,
|
xs: &Tensor,
|
||||||
@ -393,6 +405,10 @@ impl TrOCRForCausalLM {
|
|||||||
|
|
||||||
Ok(xs)
|
Ok(xs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reset_kv_cache(&mut self) {
|
||||||
|
self.decoder.reset_kv_cache();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -431,4 +447,8 @@ impl TrOCRModel {
|
|||||||
self.decoder
|
self.decoder
|
||||||
.forward(xs, Some(encoder_xs), past_kv_len, &mask)
|
.forward(xs, Some(encoder_xs), past_kv_len, &mask)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn reset_kv_cache(&mut self) {
|
||||||
|
self.decoder.reset_kv_cache();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user