mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Improve the api for the rotating cache so that the whole src tensor gets returned when it's overlarge.
This commit is contained in:
@ -210,7 +210,7 @@ impl RotatingCache {
|
||||
self.all_data = None;
|
||||
}
|
||||
|
||||
pub fn append(&mut self, src: &Tensor) -> Result<()> {
|
||||
pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
|
||||
let seq_len = src.dim(self.dim)?;
|
||||
// This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
|
||||
// self.all_data.get_or_insert_with.
|
||||
@ -222,10 +222,13 @@ impl RotatingCache {
|
||||
};
|
||||
let ad = self.all_data.as_mut().unwrap();
|
||||
|
||||
self.current_seq_len += seq_len;
|
||||
if seq_len >= self.max_seq_len {
|
||||
let src = src.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?;
|
||||
ad.slice_set(&src, self.dim, 0)?;
|
||||
self.offset = 0;
|
||||
// Here we return `src` rather than `ad` so that all the past can be used.
|
||||
Ok(src)
|
||||
} else {
|
||||
let rem_len = self.max_seq_len - self.offset;
|
||||
if seq_len <= rem_len {
|
||||
@ -241,9 +244,12 @@ impl RotatingCache {
|
||||
ad.slice_set(&src2, self.dim, 0)?;
|
||||
self.offset = seq_len - rem_len;
|
||||
}
|
||||
if self.current_seq_len >= self.max_seq_len {
|
||||
Ok(ad.clone())
|
||||
} else {
|
||||
Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
|
||||
}
|
||||
}
|
||||
self.current_seq_len += seq_len;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@ -285,27 +291,9 @@ impl RotatingKvCache {
|
||||
}
|
||||
|
||||
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||
self.k.append(k)?;
|
||||
self.v.append(v)?;
|
||||
let out_k = self.k.current_data()?;
|
||||
let out_v = self.v.current_data()?;
|
||||
let k = match out_k {
|
||||
None => {
|
||||
let mut shape = k.dims().to_vec();
|
||||
shape[self.k.dim] = 0;
|
||||
Tensor::zeros(shape, k.dtype(), k.device())?
|
||||
}
|
||||
Some(k) => k,
|
||||
};
|
||||
let v = match out_v {
|
||||
None => {
|
||||
let mut shape = v.dims().to_vec();
|
||||
shape[self.k.dim] = 0;
|
||||
Tensor::zeros(shape, v.dtype(), v.device())?
|
||||
}
|
||||
Some(v) => v,
|
||||
};
|
||||
Ok((k, v))
|
||||
let out_k = self.k.append(k)?;
|
||||
let out_v = self.v.append(v)?;
|
||||
Ok((out_k, out_v))
|
||||
}
|
||||
|
||||
pub fn offset(&self) -> usize {
|
||||
|
@ -40,51 +40,43 @@ fn rotating_kv_cache() -> Result<()> {
|
||||
let data = cache.current_data()?;
|
||||
assert!(data.is_none());
|
||||
let t = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
|
||||
cache.append(&t)?;
|
||||
let data = cache.current_data()?.unwrap();
|
||||
let data = cache.append(&t)?;
|
||||
assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3.]);
|
||||
let t = Tensor::new(&[4.], &Device::Cpu)?;
|
||||
cache.append(&t)?;
|
||||
let data = cache.current_data()?.unwrap();
|
||||
let data = cache.append(&t)?;
|
||||
assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3., 4.]);
|
||||
let t = Tensor::new(&[0., 5., 6., 7.], &Device::Cpu)?;
|
||||
cache.append(&t)?;
|
||||
let data = cache.current_data()?.unwrap();
|
||||
let data = cache.append(&t)?;
|
||||
assert_eq!(data.to_vec1::<f64>()?, [6., 7., 3., 4., 0., 5.]);
|
||||
assert_eq!(cache.current_seq_len(), 8);
|
||||
assert_eq!(cache.offset(), 2);
|
||||
|
||||
let t = Tensor::new(&[8.], &Device::Cpu)?;
|
||||
cache.append(&t)?;
|
||||
let data = cache.current_data()?.unwrap();
|
||||
let data = cache.append(&t)?;
|
||||
assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 4., 0., 5.]);
|
||||
assert_eq!(cache.current_seq_len(), 9);
|
||||
assert_eq!(cache.offset(), 3);
|
||||
|
||||
let t = Tensor::new(&[9., 10., 11.], &Device::Cpu)?;
|
||||
cache.append(&t)?;
|
||||
let data = cache.current_data()?.unwrap();
|
||||
let data = cache.append(&t)?;
|
||||
assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 9., 10., 11.]);
|
||||
assert_eq!(cache.current_seq_len(), 12);
|
||||
assert_eq!(cache.offset(), 0);
|
||||
|
||||
let t = Tensor::new(&[12.], &Device::Cpu)?;
|
||||
cache.append(&t)?;
|
||||
let data = cache.current_data()?.unwrap();
|
||||
let data = cache.append(&t)?;
|
||||
assert_eq!(data.to_vec1::<f64>()?, [12., 7., 8., 9., 10., 11.]);
|
||||
assert_eq!(cache.current_seq_len(), 13);
|
||||
assert_eq!(cache.offset(), 1);
|
||||
|
||||
let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
|
||||
cache.append(&t)?;
|
||||
let data = cache.current_data()?.unwrap();
|
||||
let data = cache.append(&t)?;
|
||||
assert_eq!(data.to_vec1::<f64>()?, [3., 4., 5., 6., 7., 8.]);
|
||||
assert_eq!(cache.current_seq_len(), 22);
|
||||
assert_eq!(cache.offset(), 0);
|
||||
|
||||
let t = Tensor::new(&[42.], &Device::Cpu)?;
|
||||
cache.append(&t)?;
|
||||
let data = cache.current_data()?.unwrap();
|
||||
let data = cache.append(&t)?;
|
||||
assert_eq!(data.to_vec1::<f64>()?, [42., 4., 5., 6., 7., 8.]);
|
||||
assert_eq!(cache.current_seq_len(), 23);
|
||||
assert_eq!(cache.offset(), 1);
|
||||
|
Reference in New Issue
Block a user