diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs index eb5dbfdb..422bf112 100644 --- a/candle-nn/src/kv_cache.rs +++ b/candle-nn/src/kv_cache.rs @@ -145,3 +145,179 @@ impl KvCache { self.v.reset(); } } + +#[derive(Debug, Clone)] +pub struct RotatingCache { + all_data: Option, + dim: usize, + // `offset` is the current write index in the buffer + offset: usize, + // The total size of the sequence seen so far. + current_seq_len: usize, + // max_seq_len is the size of the rotating buffer, it is actually allowed for the full + // sequence to grow past this limit. + max_seq_len: usize, +} + +impl RotatingCache { + pub fn new(dim: usize, max_seq_len: usize) -> Self { + Self { + all_data: None, + dim, + offset: 0, + current_seq_len: 0, + max_seq_len, + } + } + + pub fn offset(&self) -> usize { + self.offset + } + + pub fn dim(&self) -> usize { + self.dim + } + + pub fn current_seq_len(&self) -> usize { + self.current_seq_len + } + + pub fn max_seq_len(&self) -> usize { + self.max_seq_len + } + + pub fn all_data(&self) -> &Option { + &self.all_data + } + + pub fn current_data(&self) -> Result> { + let data = match self.all_data.as_ref() { + None => None, + Some(d) => { + if self.current_seq_len >= self.max_seq_len { + Some(d.clone()) + } else { + Some(d.narrow(self.dim, 0, self.current_seq_len)?) + } + } + }; + Ok(data) + } + + pub fn reset(&mut self) { + self.offset = 0; + self.current_seq_len = 0; + self.all_data = None; + } + + pub fn append(&mut self, src: &Tensor) -> Result<()> { + let seq_len = src.dim(self.dim)?; + // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use + // self.all_data.get_or_insert_with. + if self.all_data.is_none() { + let mut shape = src.dims().to_vec(); + shape[self.dim] = self.max_seq_len; + let ad = Tensor::zeros(shape, src.dtype(), src.device())?; + self.all_data = Some(ad) + }; + let ad = self.all_data.as_mut().unwrap(); + + if seq_len >= self.max_seq_len { + let src = src.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?; + ad.slice_set(&src, self.dim, 0)?; + self.offset = 0; + } else { + let rem_len = self.max_seq_len - self.offset; + if rem_len <= seq_len { + ad.slice_set(src, self.dim, self.offset)?; + self.offset = (self.offset + seq_len) % self.max_seq_len; + } else { + // We have to make two copies here as we go over the boundary of the cache. + if rem_len > 0 { + let src1 = src.narrow(self.dim, 0, rem_len)?; + ad.slice_set(&src1, self.dim, self.offset)?; + } + let src2 = src.narrow(self.dim, rem_len, seq_len - rem_len)?; + ad.slice_set(&src2, self.dim, 0)?; + self.offset = seq_len - rem_len; + } + } + self.current_seq_len += seq_len; + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct RotatingKvCache { + k: RotatingCache, + v: RotatingCache, +} + +impl RotatingKvCache { + pub fn new(dim: usize, max_seq_len: usize) -> Self { + let k = RotatingCache::new(dim, max_seq_len); + let v = RotatingCache::new(dim, max_seq_len); + Self { k, v } + } + + pub fn k_cache(&self) -> &RotatingCache { + &self.k + } + + pub fn v_cache(&self) -> &RotatingCache { + &self.v + } + + pub fn k_cache_mut(&mut self) -> &mut RotatingCache { + &mut self.k + } + + pub fn v_cache_mut(&mut self) -> &mut RotatingCache { + &mut self.v + } + + pub fn k(&self) -> Result> { + self.k.current_data() + } + + pub fn v(&self) -> Result> { + self.v.current_data() + } + + pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> { + self.k.append(k)?; + self.v.append(v)?; + let out_k = self.k.current_data()?; + let out_v = self.v.current_data()?; + let k = match out_k { + None => { + let mut shape = k.dims().to_vec(); + shape[self.k.dim] = 0; + Tensor::zeros(shape, k.dtype(), k.device())? + } + Some(k) => k, + }; + let v = match out_v { + None => { + let mut shape = v.dims().to_vec(); + shape[self.k.dim] = 0; + Tensor::zeros(shape, v.dtype(), v.device())? + } + Some(v) => v, + }; + Ok((k, v)) + } + + pub fn offset(&self) -> usize { + self.k.offset() + } + + pub fn current_seq_len(&self) -> usize { + self.k.current_seq_len() + } + + pub fn reset(&mut self) { + self.k.reset(); + self.v.reset(); + } +}