Test the reset too.

This commit is contained in:
Laurent
2024-09-22 12:46:18 +02:00
parent 1bddd44cb8
commit f9579f80be

View File

@ -9,19 +9,24 @@ use candle::{Device, Result, Tensor};
#[test] #[test]
fn kv_cache() -> Result<()> { fn kv_cache() -> Result<()> {
let mut cache = candle_nn::kv_cache::Cache::new(0, 16); let mut cache = candle_nn::kv_cache::Cache::new(0, 16);
let data = cache.current_data()?; for _ in [0, 1] {
assert!(data.is_none()); assert_eq!(cache.current_seq_len(), 0);
let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?; let data = cache.current_data()?;
cache.append(&t)?; assert!(data.is_none());
let data = cache.current_data()?.unwrap(); let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3.]); cache.append(&t)?;
let t = Tensor::new(&[4f32], &Device::Cpu)?; let data = cache.current_data()?.unwrap();
cache.append(&t)?; assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3.]);
let data = cache.current_data()?.unwrap(); let t = Tensor::new(&[4f32], &Device::Cpu)?;
assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4.]); cache.append(&t)?;
let t = Tensor::new(&[0f32, 5., 6., 7.], &Device::Cpu)?; let data = cache.current_data()?.unwrap();
cache.append(&t)?; assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4.]);
let data = cache.current_data()?.unwrap(); let t = Tensor::new(&[0f32, 5., 6., 7.], &Device::Cpu)?;
assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4., 0., 5., 6., 7.]); cache.append(&t)?;
let data = cache.current_data()?.unwrap();
assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4., 0., 5., 6., 7.]);
assert_eq!(cache.current_seq_len(), 8);
cache.reset();
}
Ok(()) Ok(())
} }