mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add a couple kv-cache helper functions. (#2206)
This commit is contained in:
@ -47,6 +47,10 @@ impl Cache {
|
|||||||
self.all_data.narrow(self.dim, 0, self.current_seq_len)
|
self.all_data.narrow(self.dim, 0, self.current_seq_len)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.current_seq_len = 0
|
||||||
|
}
|
||||||
|
|
||||||
pub fn append(&mut self, src: &Tensor) -> Result<()> {
|
pub fn append(&mut self, src: &Tensor) -> Result<()> {
|
||||||
let seq_len = src.dim(self.dim)?;
|
let seq_len = src.dim(self.dim)?;
|
||||||
if self.current_seq_len + seq_len > self.max_seq_len {
|
if self.current_seq_len + seq_len > self.max_seq_len {
|
||||||
@ -83,6 +87,22 @@ impl KvCache {
|
|||||||
Ok(Self { k, v })
|
Ok(Self { k, v })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn k_cache(&self) -> &Cache {
|
||||||
|
&self.k
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn v_cache(&self) -> &Cache {
|
||||||
|
&self.v
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn k_cache_mut(&mut self) -> &mut Cache {
|
||||||
|
&mut self.k
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn v_cache_mut(&mut self) -> &mut Cache {
|
||||||
|
&mut self.v
|
||||||
|
}
|
||||||
|
|
||||||
pub fn k(&self) -> Result<Tensor> {
|
pub fn k(&self) -> Result<Tensor> {
|
||||||
self.k.current_data()
|
self.k.current_data()
|
||||||
}
|
}
|
||||||
@ -98,4 +118,13 @@ impl KvCache {
|
|||||||
let v = self.v.current_data()?;
|
let v = self.v.current_data()?;
|
||||||
Ok((k, v))
|
Ok((k, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn current_seq_len(&self) -> usize {
|
||||||
|
self.k.current_seq_len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.k.reset();
|
||||||
|
self.v.reset();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user