From 8601537e31af610c0bbd32ee8c8ee17ed802427c Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 22 Sep 2023 12:18:16 +0100
Subject: [PATCH] Add slice-scatter. (#927)

* Add slice-scatter.

* Add the op.

* Make transpose be a no-op when the dimensions are identical.

* Add the backprop.

* And add some gradient test.
---
 candle-core/src/backprop.rs       | 12 +++++-
 candle-core/src/op.rs             |  1 +
 candle-core/src/tensor.rs         | 71 +++++++++++++++++++++++++++++++
 candle-core/tests/grad_tests.rs   | 16 +++++++
 candle-core/tests/tensor_tests.rs | 43 +++++++++++++++++++
 5 files changed, 142 insertions(+), 1 deletion(-)

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index a2548198..67207dce 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -69,7 +69,8 @@ impl Tensor {
             | Op::Binary(lhs, rhs, _)
             | Op::Gather(lhs, rhs, _)
             | Op::IndexSelect(lhs, rhs, _)
-            | Op::Matmul(lhs, rhs) => {
+            | Op::Matmul(lhs, rhs)
+            | Op::SliceScatter0(lhs, rhs, _) => {
                 let (tg, nodes) = walk(lhs, nodes, already_seen);
                 track_grad |= tg;
                 let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -270,6 +271,15 @@ impl Tensor {
                 Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                     op: "upsample-nearest2d",
                 })?,
+                Op::SliceScatter0(lhs, rhs, start_rhs) => {
+                    let rhs_sum_grad = grads.or_insert(rhs)?;
+                    let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
+                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+
+                    let lhs_sum_grad = grads.or_insert(lhs)?;
+                    let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
+                    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
+                }
                 Op::Gather(arg, indexes, dim) => {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 4882a205..3083d2c8 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -133,6 +133,7 @@ pub enum Op {
     Copy(Tensor),
     Broadcast(Tensor),
     Narrow(Tensor, usize, usize, usize),
+    SliceScatter0(Tensor, Tensor, usize),
     Reshape(Tensor),
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 9dccf2b5..d3337e16 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1132,6 +1132,74 @@ impl Tensor {
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
+    /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
+    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "slice-scatter")?;
+        if dim == 0 {
+            self.slice_scatter0(src, start)
+        } else {
+            // TODO: Maybe we want to add a more efficient implementation at some point.
+            self.transpose(0, dim)?
+                .slice_scatter0(&src.transpose(0, dim)?, start)?
+                .transpose(0, dim)
+        }
+    }
+
+    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
+    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
+        if self.dtype() != src.dtype() {
+            Err(Error::DTypeMismatchBinaryOp {
+                lhs: self.dtype(),
+                rhs: src.dtype(),
+                op: "slice-scatter",
+            }
+            .bt())?
+        }
+        if self.device().location() != src.device().location() {
+            Err(Error::DeviceMismatchBinaryOp {
+                lhs: self.device().location(),
+                rhs: src.device().location(),
+                op: "slice-scatter",
+            }
+            .bt())?
+        }
+        if self.rank() != src.rank() {
+            Err(Error::UnexpectedNumberOfDims {
+                expected: self.rank(),
+                got: src.rank(),
+                shape: src.shape().clone(),
+            }
+            .bt())?
+        }
+        let shape_ok =
+            self.dims()
+                .iter()
+                .zip(src.dims().iter())
+                .enumerate()
+                .all(|(dim_idx, (&d1, &d2))| {
+                    if 0 == dim_idx {
+                        d2 + start <= d1
+                    } else {
+                        d1 == d2
+                    }
+                });
+        if !shape_ok {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "slice-scatter (self, src)",
+                lhs: self.shape().clone(),
+                rhs: src.shape().clone(),
+            })?
+        }
+        let mut storage = self.device().zeros(self.shape(), self.dtype())?;
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        let offset = start * src.dims()[1..].iter().product::<usize>();
+        src.storage()
+            .copy_strided_src(&mut storage, offset, src.layout())?;
+        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
     /// Accumulate element from `source` at indexes `indexes` and add them to `self`.
     pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "index-add")?;
@@ -1548,6 +1616,9 @@ impl Tensor {
     pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Self> {
         let dim1 = dim1.to_index(self.shape(), "transpose")?;
         let dim2 = dim2.to_index(self.shape(), "transpose")?;
+        if dim1 == dim2 {
+            return Ok(self.clone());
+        }
         let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
         let tensor_ = Tensor_ {
             id: TensorId::new(),
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index ad09c90f..2a70cfc4 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -218,6 +218,22 @@ fn binary_grad(device: &Device) -> Result<()> {
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
     assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
+
+    let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
+    let x = x_var.as_tensor();
+    let y_var = Var::new(&[2f32, 7., 1.], device)?;
+    let y = y_var.as_tensor();
+
+    let ss = x
+        .reshape((2, 3))?
+        .slice_scatter0(&y.reshape((1, 3))?, 1)?
+        .sqr()?;
+    let grads = ss.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    let grad_y = grads.get(y).context("no grad for y")?;
+    assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
+    assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
     Ok(())
 }
 
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index edd0bd79..dbe0dd6a 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -674,6 +674,48 @@ fn index_add(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn slice_scatter(device: &Device) -> Result<()> {
+    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [3.0, 4.0, 5.0],
+            [6.0, 7.0, 8.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
+    assert_eq!(
+        t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
+        &[
+            [100.0, 101.0, 102.0],
+            [103.0, 104.0, 105.0],
+            [6.0, 7.0, 8.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    assert_eq!(
+        t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [100.0, 101.0, 102.0],
+            [103.0, 104.0, 105.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    assert_eq!(
+        t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [3.0, 4.0, 5.0],
+            [100.0, 101.0, 102.0],
+            [103.0, 104.0, 105.0],
+        ]
+    );
+    Ok(())
+}
+
 fn scatter_add(device: &Device) -> Result<()> {
     let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
     assert_eq!(
@@ -946,6 +988,7 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
 test_device!(index_add, index_add_cpu, index_add_gpu);
 test_device!(gather, gather_cpu, gather_gpu);
 test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
 test_device!(randn, randn_cpu, randn_gpu);
 test_device!(clamp, clamp_cpu, clamp_gpu);
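
Usage sketch (not part of the diff): a minimal standalone example of the API added above, assuming the `candle-core` crate as of this commit. The tensor values mirror the tests; the names `rows`, `cols`, and `src2` are illustrative. The shape rule from `slice_scatter0` applies: `src` must match `self` on every dimension except the scatter dimension, and `start + src.dim(dim) <= self.dim(dim)`.

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // A 4x3 destination and a 2x3 source, as in the tests above.
        let t = Tensor::arange(0f32, 12f32, &device)?.reshape((4, 3))?;
        let src = Tensor::arange(100f32, 106f32, &device)?.reshape((2, 3))?;
        // Overwrite rows 1..3 of `t` with `src` along the first dimension.
        let rows = t.slice_scatter0(&src, 1)?;
        // The generic version handles other dims by moving `dim` to the
        // front, scattering there, and moving it back (a no-op transpose
        // when dim == 0). Here a 4x2 block lands in columns 0..2 of `t`.
        let src2 = Tensor::arange(100f32, 108f32, &device)?.reshape((4, 2))?;
        let cols = t.slice_scatter(&src2, 1, 0)?;
        println!("{rows}\n{cols}");
        Ok(())
    }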
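
And a sketch of the backward rule implemented in backprop.rs above, pulled out of the autograd machinery (same `candle-core` assumption; `grad` stands in for the incoming output gradient): the source picks up the `narrow` of the gradient over the rows it was written to, while the destination gets the gradient with those rows zeroed, which is conveniently expressed with `slice_scatter0` itself.

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Stand-in for the gradient w.r.t. the output of `slice_scatter0`.
        let grad = Tensor::ones((4, 3), DType::F32, &device)?;
        let (start, len) = (1, 2); // the scattered `src` covered rows 1..3
        // d(loss)/d(src): the slice of `grad` that `src` was copied into.
        let grad_src = grad.narrow(0, start, len)?;
        // d(loss)/d(self): `grad` with the overwritten rows zeroed out,
        // since those elements of `self` never reached the output.
        let grad_self = grad.slice_scatter0(&grad_src.zeros_like()?, start)?;
        println!("{grad_src}\n{grad_self}");
        Ok(())
    }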