mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add a slice_set op. (#2193)
* Add a slice_set op. * Add some testing. * Add the dedicated kv-cache module. * Derive debug and clone. * Expose more kv-cache functions. * Return the current data when appending. * Use the new cache in the quantized phi3 model.
This commit is contained in:
@ -235,4 +235,66 @@ impl Tensor {
|
||||
}
|
||||
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
/// Set the values on `self` using values from `src`. The copy starts at the specified
|
||||
/// `offset` for the target dimension `dim` on `self`.
|
||||
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
|
||||
/// has to be greater than or equal to `offset` plus the `src` size.
|
||||
///
|
||||
/// Note that this modifies `self` in place and as such is not compatibel with
|
||||
/// back-propagation.
|
||||
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
|
||||
let dim = dim.to_index(self.shape(), "slice-set")?;
|
||||
if !self.is_contiguous() || !src.is_contiguous() {
|
||||
Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
|
||||
}
|
||||
if self.dtype() != src.dtype() {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: self.dtype(),
|
||||
rhs: src.dtype(),
|
||||
op: "slice-set",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if self.device().location() != src.device().location() {
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: self.device().location(),
|
||||
rhs: src.device().location(),
|
||||
op: "slice-set",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if self.rank() != src.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: self.rank(),
|
||||
got: src.rank(),
|
||||
shape: self.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
|
||||
if dim_idx == dim && *v2 + offset > *v1 {
|
||||
crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
|
||||
}
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
|
||||
}
|
||||
}
|
||||
let block_size: usize = src.dims().iter().skip(1 + dim).product();
|
||||
let d1: usize = src.dims().iter().take(dim).product();
|
||||
let d2 = block_size * src.dims()[dim];
|
||||
let dst_o = self.layout().start_offset() + offset * block_size;
|
||||
let src_o = src.layout().start_offset();
|
||||
src.storage().copy2d(
|
||||
&mut self.storage_mut(),
|
||||
d1,
|
||||
d2,
|
||||
/* src_s */ d2,
|
||||
/* dst_s */ block_size * self.dims()[dim],
|
||||
src_o,
|
||||
dst_o,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -665,6 +665,30 @@ fn broadcast(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn slice_set(device: &Device) -> Result<()> {
|
||||
let (b, h, max_t, d) = (2, 4, 7, 3);
|
||||
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
|
||||
let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
|
||||
cache.slice_set(&tensor, 2, 0)?;
|
||||
let cache_t = cache.narrow(2, 0, 4)?;
|
||||
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
cache.slice_set(&tensor, 2, 1)?;
|
||||
let cache_t = cache.narrow(2, 1, 4)?;
|
||||
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
|
||||
cache.slice_set(&ones, 2, 6)?;
|
||||
let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
let diff = (cache.narrow(2, 6, 1)? - 1.)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cat(device: &Device) -> Result<()> {
|
||||
// 1D
|
||||
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||
@ -1146,6 +1170,7 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
|
||||
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
|
||||
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
|
||||
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
|
||||
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
|
||||
test_device!(min, min_cpu, min_gpu, min_metal);
|
||||
|
Reference in New Issue
Block a user