Add a slice_set op. (#2193)

* Add a slice_set op.

* Add some testing.

* Add the dedicated kv-cache module.

* Derive debug and clone.

* Expose more kv-cache functions.

* Return the current data when appending.

* Use the new cache in the quantized phi3 model.
This commit is contained in:
Laurent Mazare
2024-05-18 15:58:18 +02:00
committed by GitHub
parent 349c3e806a
commit 01545f7303
6 changed files with 209 additions and 23 deletions

View File

@ -235,4 +235,66 @@ impl Tensor {
}
Ok(crate::tensor::from_storage(storage, shape, op, false))
}
/// Set the values on `self` using values from `src`. The copy starts at the specified
/// `offset` for the target dimension `dim` on `self`.
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
/// has to be greater than or equal to `offset` plus the `src` size.
///
/// Note that this modifies `self` in place and as such is not compatibel with
/// back-propagation.
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
let dim = dim.to_index(self.shape(), "slice-set")?;
if !self.is_contiguous() || !src.is_contiguous() {
Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
}
if self.dtype() != src.dtype() {
Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: src.dtype(),
op: "slice-set",
}
.bt())?
}
if self.device().location() != src.device().location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: self.device().location(),
rhs: src.device().location(),
op: "slice-set",
}
.bt())?
}
if self.rank() != src.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: self.rank(),
got: src.rank(),
shape: self.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
if dim_idx == dim && *v2 + offset > *v1 {
crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
}
if dim_idx != dim && v1 != v2 {
crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
}
}
let block_size: usize = src.dims().iter().skip(1 + dim).product();
let d1: usize = src.dims().iter().take(dim).product();
let d2 = block_size * src.dims()[dim];
let dst_o = self.layout().start_offset() + offset * block_size;
let src_o = src.layout().start_offset();
src.storage().copy2d(
&mut self.storage_mut(),
d1,
d2,
/* src_s */ d2,
/* dst_s */ block_size * self.dims()[dim],
src_o,
dst_o,
)?;
Ok(())
}
}

View File

@ -665,6 +665,30 @@ fn broadcast(device: &Device) -> Result<()> {
Ok(())
}
fn slice_set(device: &Device) -> Result<()> {
let (b, h, max_t, d) = (2, 4, 7, 3);
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
cache.slice_set(&tensor, 2, 0)?;
let cache_t = cache.narrow(2, 0, 4)?;
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
cache.slice_set(&tensor, 2, 1)?;
let cache_t = cache.narrow(2, 1, 4)?;
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
cache.slice_set(&ones, 2, 6)?;
let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let diff = (cache.narrow(2, 6, 1)? - 1.)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
Ok(())
}
fn cat(device: &Device) -> Result<()> {
// 1D
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
@ -1146,6 +1170,7 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);