Allow for growing the default KV cache when needed. (#2810)

This commit is contained in:
Laurent Mazare
2025-03-16 17:30:25 +01:00
committed by GitHub
parent cbf5fc80c2
commit 3afb04925a

View File

@@ -11,6 +11,7 @@ pub struct Cache {
all_data: Option<Tensor>, all_data: Option<Tensor>,
dim: usize, dim: usize,
current_seq_len: usize, current_seq_len: usize,
grow_by: usize,
max_seq_len: usize, max_seq_len: usize,
} }
@@ -20,6 +21,7 @@ impl Cache {
all_data: None, all_data: None,
dim, dim,
current_seq_len: 0, current_seq_len: 0,
grow_by: max_seq_len,
max_seq_len, max_seq_len,
} }
} }
@@ -65,11 +67,11 @@ impl Cache {
}; };
let ad = self.all_data.as_mut().unwrap(); let ad = self.all_data.as_mut().unwrap();
if self.current_seq_len + seq_len > self.max_seq_len { if self.current_seq_len + seq_len > self.max_seq_len {
candle::bail!( let mut shape = src.dims().to_vec();
"kv-cache: above max-seq-len {}+{seq_len}>{}", shape[self.dim] = self.grow_by;
self.current_seq_len, let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.max_seq_len *ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?;
) self.max_seq_len += self.grow_by;
} }
ad.slice_set(src, self.dim, self.current_seq_len)?; ad.slice_set(src, self.dim, self.current_seq_len)?;
self.current_seq_len += seq_len; self.current_seq_len += seq_len;