use candle::{Result, Tensor};

#[derive(Debug, Clone)]
pub struct Cache {
    // all_data is an Option on a Tensor, this makes it possible to only create the actual tensor
    // on the first call where the batch size is easily known.
    // Also this makes it safe to clone a KvCache that has been reset (as in it will not share
    // its internal state with the cloned instance).
    all_data: Option<Tensor>,
    dim: usize,
    current_seq_len: usize,
    max_seq_len: usize,
}

impl Cache {
    pub fn new(dim: usize, max_seq_len: usize) -> Self {
        Self {
            all_data: None,
            dim,
            current_seq_len: 0,
            max_seq_len,
        }
    }

    pub fn dim(&self) -> usize {
        self.dim
    }

    pub fn current_seq_len(&self) -> usize {
        self.current_seq_len
    }

    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    pub fn all_data(&self) -> &Option<Tensor> {
        &self.all_data
    }

    pub fn current_data(&self) -> Result<Option<Tensor>> {
        let data = match self.all_data.as_ref() {
            None => None,
            Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
        };
        Ok(data)
    }

    pub fn reset(&mut self) {
        self.current_seq_len = 0;
        self.all_data = None;
    }

    pub fn append(&mut self, src: &Tensor) -> Result<()> {
        let seq_len = src.dim(self.dim)?;
        // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
        // self.all_data.get_or_insert_with.
        if self.all_data.is_none() {
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.max_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            self.all_data = Some(ad)
        };
        let ad = self.all_data.as_mut().unwrap();
        if self.current_seq_len + seq_len > self.max_seq_len {
            candle::bail!(
                "kv-cache: above max-seq-len {}+{seq_len}>{}",
                self.current_seq_len,
                self.max_seq_len
            )
        }
        ad.slice_set(src, self.dim, self.current_seq_len)?;
        self.current_seq_len += seq_len;
        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct KvCache {
    k: Cache,
    v: Cache,
}

impl KvCache {
    pub fn new(dim: usize, max_seq_len: usize) -> Self {
        let k = Cache::new(dim, max_seq_len);
        let v = Cache::new(dim, max_seq_len);
        Self { k, v }
    }

    pub fn k_cache(&self) -> &Cache {
        &self.k
    }

    pub fn v_cache(&self) -> &Cache {
        &self.v
    }

    pub fn k_cache_mut(&mut self) -> &mut Cache {
        &mut self.k
    }

    pub fn v_cache_mut(&mut self) -> &mut Cache {
        &mut self.v
    }

    pub fn k(&self) -> Result<Option<Tensor>> {
        self.k.current_data()
    }

    pub fn v(&self) -> Result<Option<Tensor>> {
        self.v.current_data()
    }

    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        self.k.append(k)?;
        self.v.append(v)?;
        let out_k = self.k.current_data()?;
        let out_v = self.v.current_data()?;
        let k = match out_k {
            None => {
                let mut shape = k.dims().to_vec();
                shape[self.k.dim] = 0;
                Tensor::zeros(shape, k.dtype(), k.device())?
            }
            Some(k) => k,
        };
        let v = match out_v {
            None => {
                let mut shape = v.dims().to_vec();
                shape[self.k.dim] = 0;
                Tensor::zeros(shape, v.dtype(), v.device())?
            }
            Some(v) => v,
        };
        Ok((k, v))
    }

    pub fn current_seq_len(&self) -> usize {
        self.k.current_seq_len()
    }

    pub fn reset(&mut self) {
        self.k.reset();
        self.v.reset();
    }
}
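
// A minimal usage sketch, not part of the original file: it assumes the CPU backend and
// tensors shaped (batch, heads, seq, head_dim), caching along the sequence dimension
// (dim 2) as a typical attention layer would. The shapes and sizes are illustrative only.
#[cfg(test)]
mod tests {
    use super::*;
    use candle::{DType, Device};

    #[test]
    fn kv_cache_append() -> Result<()> {
        let dev = Device::Cpu;
        // Cache keys/values along dim 2, with room for at most 8 positions.
        let mut cache = KvCache::new(2, 8);

        // First step: a prompt of 3 tokens; the returned tensors cover the 3 cached positions.
        let k = Tensor::zeros((1, 4, 3, 16), DType::F32, &dev)?;
        let v = Tensor::zeros((1, 4, 3, 16), DType::F32, &dev)?;
        let (k_all, v_all) = cache.append(&k, &v)?;
        assert_eq!(k_all.dims()[2], 3);
        assert_eq!(v_all.dims()[2], 3);

        // Second step: a single decoded token; the returned tensors now cover all 4 positions.
        let k = Tensor::zeros((1, 4, 1, 16), DType::F32, &dev)?;
        let v = Tensor::zeros((1, 4, 1, 16), DType::F32, &dev)?;
        let (k_all, v_all) = cache.append(&k, &v)?;
        assert_eq!(cache.current_seq_len(), 4);
        assert_eq!(k_all.dims()[2], 4);
        assert_eq!(v_all.dims()[2], 4);

        // Resetting clears the cached state so the cache can be reused for a new sequence.
        cache.reset();
        assert_eq!(cache.current_seq_len(), 0);
        Ok(())
    }
}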