Add a RotatingKVCache.

This commit is contained in:
Laurent
2024-09-22 12:31:25 +02:00
parent c2fca0ca11
commit 9cfe3c7141

View File

@ -145,3 +145,179 @@ impl KvCache {
self.v.reset();
}
}
#[derive(Debug, Clone)]
pub struct RotatingCache {
all_data: Option<Tensor>,
dim: usize,
// `offset` is the current write index in the buffer
offset: usize,
// The total size of the sequence seen so far.
current_seq_len: usize,
// max_seq_len is the size of the rotating buffer, it is actually allowed for the full
// sequence to grow past this limit.
max_seq_len: usize,
}
impl RotatingCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
offset: 0,
current_seq_len: 0,
max_seq_len,
}
}
pub fn offset(&self) -> usize {
self.offset
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn current_seq_len(&self) -> usize {
self.current_seq_len
}
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
pub fn all_data(&self) -> &Option<Tensor> {
&self.all_data
}
pub fn current_data(&self) -> Result<Option<Tensor>> {
let data = match self.all_data.as_ref() {
None => None,
Some(d) => {
if self.current_seq_len >= self.max_seq_len {
Some(d.clone())
} else {
Some(d.narrow(self.dim, 0, self.current_seq_len)?)
}
}
};
Ok(data)
}
pub fn reset(&mut self) {
self.offset = 0;
self.current_seq_len = 0;
self.all_data = None;
}
pub fn append(&mut self, src: &Tensor) -> Result<()> {
let seq_len = src.dim(self.dim)?;
// This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
// self.all_data.get_or_insert_with.
if self.all_data.is_none() {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.max_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad)
};
let ad = self.all_data.as_mut().unwrap();
if seq_len >= self.max_seq_len {
let src = src.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?;
ad.slice_set(&src, self.dim, 0)?;
self.offset = 0;
} else {
let rem_len = self.max_seq_len - self.offset;
if rem_len <= seq_len {
ad.slice_set(src, self.dim, self.offset)?;
self.offset = (self.offset + seq_len) % self.max_seq_len;
} else {
// We have to make two copies here as we go over the boundary of the cache.
if rem_len > 0 {
let src1 = src.narrow(self.dim, 0, rem_len)?;
ad.slice_set(&src1, self.dim, self.offset)?;
}
let src2 = src.narrow(self.dim, rem_len, seq_len - rem_len)?;
ad.slice_set(&src2, self.dim, 0)?;
self.offset = seq_len - rem_len;
}
}
self.current_seq_len += seq_len;
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct RotatingKvCache {
k: RotatingCache,
v: RotatingCache,
}
impl RotatingKvCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
let k = RotatingCache::new(dim, max_seq_len);
let v = RotatingCache::new(dim, max_seq_len);
Self { k, v }
}
pub fn k_cache(&self) -> &RotatingCache {
&self.k
}
pub fn v_cache(&self) -> &RotatingCache {
&self.v
}
pub fn k_cache_mut(&mut self) -> &mut RotatingCache {
&mut self.k
}
pub fn v_cache_mut(&mut self) -> &mut RotatingCache {
&mut self.v
}
pub fn k(&self) -> Result<Option<Tensor>> {
self.k.current_data()
}
pub fn v(&self) -> Result<Option<Tensor>> {
self.v.current_data()
}
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.k.append(k)?;
self.v.append(v)?;
let out_k = self.k.current_data()?;
let out_v = self.v.current_data()?;
let k = match out_k {
None => {
let mut shape = k.dims().to_vec();
shape[self.k.dim] = 0;
Tensor::zeros(shape, k.dtype(), k.device())?
}
Some(k) => k,
};
let v = match out_v {
None => {
let mut shape = v.dims().to_vec();
shape[self.k.dim] = 0;
Tensor::zeros(shape, v.dtype(), v.device())?
}
Some(v) => v,
};
Ok((k, v))
}
pub fn offset(&self) -> usize {
self.k.offset()
}
pub fn current_seq_len(&self) -> usize {
self.k.current_seq_len()
}
pub fn reset(&mut self) {
self.k.reset();
self.v.reset();
}
}