From fc877920cecb1736c43970c7def4239d8daa8082 Mon Sep 17 00:00:00 2001 From: Laurent Date: Sun, 22 Sep 2024 12:57:46 +0200 Subject: [PATCH] More tests for the rotating kv-cache. --- candle-nn/tests/kv_cache.rs | 49 ++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/candle-nn/tests/kv_cache.rs b/candle-nn/tests/kv_cache.rs index 91167b77..3fc40a57 100644 --- a/candle-nn/tests/kv_cache.rs +++ b/candle-nn/tests/kv_cache.rs @@ -35,23 +35,60 @@ fn kv_cache() -> Result<()> { fn rotating_kv_cache() -> Result<()> { let mut cache = candle_nn::kv_cache::RotatingCache::new(0, 6); for _ in [0, 1] { + assert_eq!(cache.offset(), 0); assert_eq!(cache.current_seq_len(), 0); let data = cache.current_data()?; assert!(data.is_none()); - let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?; + let t = Tensor::new(&[1., 2., 3.], &Device::Cpu)?; cache.append(&t)?; let data = cache.current_data()?.unwrap(); - assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3.]); - let t = Tensor::new(&[4f32], &Device::Cpu)?; + assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3.]); + let t = Tensor::new(&[4.], &Device::Cpu)?; cache.append(&t)?; let data = cache.current_data()?.unwrap(); - assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4.]); - let t = Tensor::new(&[0f32, 5., 6., 7.], &Device::Cpu)?; + assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3., 4.]); + let t = Tensor::new(&[0., 5., 6., 7.], &Device::Cpu)?; cache.append(&t)?; let data = cache.current_data()?.unwrap(); - assert_eq!(data.to_vec1::<f32>()?, [6., 7., 3., 4., 0., 5.]); + assert_eq!(data.to_vec1::<f64>()?, [6., 7., 3., 4., 0., 5.]); assert_eq!(cache.current_seq_len(), 8); assert_eq!(cache.offset(), 2); + + let t = Tensor::new(&[8.], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 4., 0., 5.]); + assert_eq!(cache.current_seq_len(), 9); + assert_eq!(cache.offset(), 3); + + let t = Tensor::new(&[9., 10., 11.], &Device::Cpu)?; + cache.append(&t)?; + let data = 
cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 9., 10., 11.]); + assert_eq!(cache.current_seq_len(), 12); + assert_eq!(cache.offset(), 0); + + let t = Tensor::new(&[12.], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::<f64>()?, [12., 7., 8., 9., 10., 11.]); + assert_eq!(cache.current_seq_len(), 13); + assert_eq!(cache.offset(), 1); + + let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::<f64>()?, [3., 4., 5., 6., 7., 8.]); + assert_eq!(cache.current_seq_len(), 22); + assert_eq!(cache.offset(), 0); + + let t = Tensor::new(&[42.], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::<f64>()?, [42., 4., 5., 6., 7., 8.]); + assert_eq!(cache.current_seq_len(), 23); + assert_eq!(cache.offset(), 1); + cache.reset(); } Ok(())