diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs index 918dca70..f0be71e1 100644 --- a/candle-nn/src/kv_cache.rs +++ b/candle-nn/src/kv_cache.rs @@ -11,6 +11,7 @@ pub struct Cache { all_data: Option, dim: usize, current_seq_len: usize, + grow_by: usize, max_seq_len: usize, } @@ -20,6 +21,7 @@ impl Cache { all_data: None, dim, current_seq_len: 0, + grow_by: max_seq_len, max_seq_len, } } @@ -65,11 +67,11 @@ impl Cache { }; let ad = self.all_data.as_mut().unwrap(); if self.current_seq_len + seq_len > self.max_seq_len { - candle::bail!( - "kv-cache: above max-seq-len {}+{seq_len}>{}", - self.current_seq_len, - self.max_seq_len - ) + let mut shape = src.dims().to_vec(); + shape[self.dim] = self.grow_by; + let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?; + *ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?; + self.max_seq_len += self.grow_by; } ad.slice_set(src, self.dim, self.current_seq_len)?; self.current_seq_len += seq_len;