mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Allow for growing the default KV cache when needed. (#2810)
This commit is contained in:
@ -11,6 +11,7 @@ pub struct Cache {
|
|||||||
all_data: Option<Tensor>,
|
all_data: Option<Tensor>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
current_seq_len: usize,
|
current_seq_len: usize,
|
||||||
|
grow_by: usize,
|
||||||
max_seq_len: usize,
|
max_seq_len: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -20,6 +21,7 @@ impl Cache {
|
|||||||
all_data: None,
|
all_data: None,
|
||||||
dim,
|
dim,
|
||||||
current_seq_len: 0,
|
current_seq_len: 0,
|
||||||
|
grow_by: max_seq_len,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -65,11 +67,11 @@ impl Cache {
|
|||||||
};
|
};
|
||||||
let ad = self.all_data.as_mut().unwrap();
|
let ad = self.all_data.as_mut().unwrap();
|
||||||
if self.current_seq_len + seq_len > self.max_seq_len {
|
if self.current_seq_len + seq_len > self.max_seq_len {
|
||||||
candle::bail!(
|
let mut shape = src.dims().to_vec();
|
||||||
"kv-cache: above max-seq-len {}+{seq_len}>{}",
|
shape[self.dim] = self.grow_by;
|
||||||
self.current_seq_len,
|
let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
|
||||||
self.max_seq_len
|
*ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?;
|
||||||
)
|
self.max_seq_len += self.grow_by;
|
||||||
}
|
}
|
||||||
ad.slice_set(src, self.dim, self.current_seq_len)?;
|
ad.slice_set(src, self.dim, self.current_seq_len)?;
|
||||||
self.current_seq_len += seq_len;
|
self.current_seq_len += seq_len;
|
||||||
|
Reference in New Issue
Block a user