Test the reset too.

This commit is contained in:
Laurent
2024-09-22 12:46:18 +02:00
parent 1bddd44cb8
commit f9579f80be

View File

@ -9,6 +9,8 @@ use candle::{Device, Result, Tensor};
#[test] #[test]
fn kv_cache() -> Result<()> { fn kv_cache() -> Result<()> {
let mut cache = candle_nn::kv_cache::Cache::new(0, 16); let mut cache = candle_nn::kv_cache::Cache::new(0, 16);
for _ in [0, 1] {
assert_eq!(cache.current_seq_len(), 0);
let data = cache.current_data()?; let data = cache.current_data()?;
assert!(data.is_none()); assert!(data.is_none());
let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?; let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
@ -23,5 +25,8 @@ fn kv_cache() -> Result<()> {
cache.append(&t)?; cache.append(&t)?;
let data = cache.current_data()?.unwrap(); let data = cache.current_data()?.unwrap();
assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4., 0., 5., 6., 7.]); assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4., 0., 5., 6., 7.]);
assert_eq!(cache.current_seq_len(), 8);
cache.reset();
}
Ok(()) Ok(())
} }