Add a couple kv-cache helper functions. (#2206)

This commit is contained in:
Laurent Mazare
2024-05-23 16:21:47 +02:00
committed by GitHub
parent 77ea479a18
commit 31cf64147b

View File

@ -47,6 +47,10 @@ impl Cache {
self.all_data.narrow(self.dim, 0, self.current_seq_len)
}
pub fn reset(&mut self) {
self.current_seq_len = 0
}
pub fn append(&mut self, src: &Tensor) -> Result<()> {
let seq_len = src.dim(self.dim)?;
if self.current_seq_len + seq_len > self.max_seq_len {
@ -83,6 +87,22 @@ impl KvCache {
Ok(Self { k, v })
}
pub fn k_cache(&self) -> &Cache {
&self.k
}
pub fn v_cache(&self) -> &Cache {
&self.v
}
pub fn k_cache_mut(&mut self) -> &mut Cache {
&mut self.k
}
pub fn v_cache_mut(&mut self) -> &mut Cache {
&mut self.v
}
pub fn k(&self) -> Result<Tensor> {
self.k.current_data()
}
@ -98,4 +118,13 @@ impl KvCache {
let v = self.v.current_data()?;
Ok((k, v))
}
pub fn current_seq_len(&self) -> usize {
self.k.current_seq_len()
}
pub fn reset(&mut self) {
self.k.reset();
self.v.reset();
}
}