From 6547c4bfc3af054d78a0ce7b44a4c681d1c2abf0 Mon Sep 17 00:00:00 2001 From: Laurent Date: Sun, 22 Sep 2024 12:51:07 +0200 Subject: [PATCH] More kv-cache testing. --- candle-nn/src/kv_cache.rs | 2 +- candle-nn/tests/kv_cache.rs | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs index 422bf112..02651cb0 100644 --- a/candle-nn/src/kv_cache.rs +++ b/candle-nn/src/kv_cache.rs @@ -228,7 +228,7 @@ impl RotatingCache { self.offset = 0; } else { let rem_len = self.max_seq_len - self.offset; - if rem_len <= seq_len { + if seq_len <= rem_len { ad.slice_set(src, self.dim, self.offset)?; self.offset = (self.offset + seq_len) % self.max_seq_len; } else { diff --git a/candle-nn/tests/kv_cache.rs b/candle-nn/tests/kv_cache.rs index cc016ff1..91167b77 100644 --- a/candle-nn/tests/kv_cache.rs +++ b/candle-nn/tests/kv_cache.rs @@ -30,3 +30,29 @@ fn kv_cache() -> Result<()> { } Ok(()) } + +#[test] +fn rotating_kv_cache() -> Result<()> { + let mut cache = candle_nn::kv_cache::RotatingCache::new(0, 6); + for _ in [0, 1] { + assert_eq!(cache.current_seq_len(), 0); + let data = cache.current_data()?; + assert!(data.is_none()); + let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::()?, [1., 2., 3.]); + let t = Tensor::new(&[4f32], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::()?, [1., 2., 3., 4.]); + let t = Tensor::new(&[0f32, 5., 6., 7.], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::()?, [6., 7., 3., 4., 0., 5.]); + assert_eq!(cache.current_seq_len(), 8); + assert_eq!(cache.offset(), 2); + cache.reset(); + } + Ok(()) +}