mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Add a couple kv-cache helper functions. (#2206)
This commit is contained in:
@ -47,6 +47,10 @@ impl Cache {
|
||||
self.all_data.narrow(self.dim, 0, self.current_seq_len)
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.current_seq_len = 0
|
||||
}
|
||||
|
||||
pub fn append(&mut self, src: &Tensor) -> Result<()> {
|
||||
let seq_len = src.dim(self.dim)?;
|
||||
if self.current_seq_len + seq_len > self.max_seq_len {
|
||||
@ -83,6 +87,22 @@ impl KvCache {
|
||||
Ok(Self { k, v })
|
||||
}
|
||||
|
||||
pub fn k_cache(&self) -> &Cache {
|
||||
&self.k
|
||||
}
|
||||
|
||||
pub fn v_cache(&self) -> &Cache {
|
||||
&self.v
|
||||
}
|
||||
|
||||
pub fn k_cache_mut(&mut self) -> &mut Cache {
|
||||
&mut self.k
|
||||
}
|
||||
|
||||
pub fn v_cache_mut(&mut self) -> &mut Cache {
|
||||
&mut self.v
|
||||
}
|
||||
|
||||
pub fn k(&self) -> Result<Tensor> {
|
||||
self.k.current_data()
|
||||
}
|
||||
@ -98,4 +118,13 @@ impl KvCache {
|
||||
let v = self.v.current_data()?;
|
||||
Ok((k, v))
|
||||
}
|
||||
|
||||
pub fn current_seq_len(&self) -> usize {
|
||||
self.k.current_seq_len()
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.k.reset();
|
||||
self.v.reset();
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user