mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Allow for growing the default KV cache when needed. (#2810)
This commit is contained in:
@ -11,6 +11,7 @@ pub struct Cache {
|
||||
all_data: Option<Tensor>,
|
||||
dim: usize,
|
||||
current_seq_len: usize,
|
||||
grow_by: usize,
|
||||
max_seq_len: usize,
|
||||
}
|
||||
|
||||
@ -20,6 +21,7 @@ impl Cache {
|
||||
all_data: None,
|
||||
dim,
|
||||
current_seq_len: 0,
|
||||
grow_by: max_seq_len,
|
||||
max_seq_len,
|
||||
}
|
||||
}
|
||||
@ -65,11 +67,11 @@ impl Cache {
|
||||
};
|
||||
let ad = self.all_data.as_mut().unwrap();
|
||||
if self.current_seq_len + seq_len > self.max_seq_len {
|
||||
candle::bail!(
|
||||
"kv-cache: above max-seq-len {}+{seq_len}>{}",
|
||||
self.current_seq_len,
|
||||
self.max_seq_len
|
||||
)
|
||||
let mut shape = src.dims().to_vec();
|
||||
shape[self.dim] = self.grow_by;
|
||||
let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
|
||||
*ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?;
|
||||
self.max_seq_len += self.grow_by;
|
||||
}
|
||||
ad.slice_set(src, self.dim, self.current_seq_len)?;
|
||||
self.current_seq_len += seq_len;
|
||||
|
Reference in New Issue
Block a user