diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs index 02651cb0..59ff5648 100644 --- a/candle-nn/src/kv_cache.rs +++ b/candle-nn/src/kv_cache.rs @@ -210,7 +210,7 @@ impl RotatingCache { self.all_data = None; } - pub fn append(&mut self, src: &Tensor) -> Result<()> { + pub fn append(&mut self, src: &Tensor) -> Result { let seq_len = src.dim(self.dim)?; // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use // self.all_data.get_or_insert_with. @@ -222,10 +222,13 @@ impl RotatingCache { }; let ad = self.all_data.as_mut().unwrap(); + self.current_seq_len += seq_len; if seq_len >= self.max_seq_len { let src = src.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?; ad.slice_set(&src, self.dim, 0)?; self.offset = 0; + // Here we return `src` rather than `ad` so that all the past can be used. + Ok(src) } else { let rem_len = self.max_seq_len - self.offset; if seq_len <= rem_len { @@ -241,9 +244,12 @@ impl RotatingCache { ad.slice_set(&src2, self.dim, 0)?; self.offset = seq_len - rem_len; } + if self.current_seq_len >= self.max_seq_len { + Ok(ad.clone()) + } else { + Ok(ad.narrow(self.dim, 0, self.current_seq_len)?) + } } - self.current_seq_len += seq_len; - Ok(()) } } @@ -285,27 +291,9 @@ impl RotatingKvCache { } pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> { - self.k.append(k)?; - self.v.append(v)?; - let out_k = self.k.current_data()?; - let out_v = self.v.current_data()?; - let k = match out_k { - None => { - let mut shape = k.dims().to_vec(); - shape[self.k.dim] = 0; - Tensor::zeros(shape, k.dtype(), k.device())? - } - Some(k) => k, - }; - let v = match out_v { - None => { - let mut shape = v.dims().to_vec(); - shape[self.k.dim] = 0; - Tensor::zeros(shape, v.dtype(), v.device())? - } - Some(v) => v, - }; - Ok((k, v)) + let out_k = self.k.append(k)?; + let out_v = self.v.append(v)?; + Ok((out_k, out_v)) } pub fn offset(&self) -> usize { diff --git a/candle-nn/tests/kv_cache.rs b/candle-nn/tests/kv_cache.rs index 3fc40a57..2f70f3d4 100644 --- a/candle-nn/tests/kv_cache.rs +++ b/candle-nn/tests/kv_cache.rs @@ -40,51 +40,43 @@ fn rotating_kv_cache() -> Result<()> { let data = cache.current_data()?; assert!(data.is_none()); let t = Tensor::new(&[1., 2., 3.], &Device::Cpu)?; - cache.append(&t)?; - let data = cache.current_data()?.unwrap(); + let data = cache.append(&t)?; assert_eq!(data.to_vec1::()?, [1., 2., 3.]); let t = Tensor::new(&[4.], &Device::Cpu)?; - cache.append(&t)?; - let data = cache.current_data()?.unwrap(); + let data = cache.append(&t)?; assert_eq!(data.to_vec1::()?, [1., 2., 3., 4.]); let t = Tensor::new(&[0., 5., 6., 7.], &Device::Cpu)?; - cache.append(&t)?; - let data = cache.current_data()?.unwrap(); + let data = cache.append(&t)?; assert_eq!(data.to_vec1::()?, [6., 7., 3., 4., 0., 5.]); assert_eq!(cache.current_seq_len(), 8); assert_eq!(cache.offset(), 2); let t = Tensor::new(&[8.], &Device::Cpu)?; - cache.append(&t)?; - let data = cache.current_data()?.unwrap(); + let data = cache.append(&t)?; assert_eq!(data.to_vec1::()?, [6., 7., 8., 4., 0., 5.]); assert_eq!(cache.current_seq_len(), 9); assert_eq!(cache.offset(), 3); let t = Tensor::new(&[9., 10., 11.], &Device::Cpu)?; - cache.append(&t)?; - let data = cache.current_data()?.unwrap(); + let data = cache.append(&t)?; assert_eq!(data.to_vec1::()?, [6., 7., 8., 9., 10., 11.]); assert_eq!(cache.current_seq_len(), 12); assert_eq!(cache.offset(), 0); let t = Tensor::new(&[12.], &Device::Cpu)?; - cache.append(&t)?; - let data = cache.current_data()?.unwrap(); + let data = cache.append(&t)?; assert_eq!(data.to_vec1::()?, [12., 7., 8., 9., 10., 11.]); assert_eq!(cache.current_seq_len(), 13); assert_eq!(cache.offset(), 1); let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?; - cache.append(&t)?; - let data = cache.current_data()?.unwrap(); + let data = cache.append(&t)?; assert_eq!(data.to_vec1::()?, [3., 4., 5., 6., 7., 8.]); assert_eq!(cache.current_seq_len(), 22); assert_eq!(cache.offset(), 0); let t = Tensor::new(&[42.], &Device::Cpu)?; - cache.append(&t)?; - let data = cache.current_data()?.unwrap(); + let data = cache.append(&t)?; assert_eq!(data.to_vec1::()?, [42., 4., 5., 6., 7., 8.]); assert_eq!(cache.current_seq_len(), 23); assert_eq!(cache.offset(), 1);