From 18cc73954a318c46743aa6f6e9748793116ac0a2 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 25 Jul 2023 08:38:33 +0100 Subject: [PATCH] Add some testing for index-add (#237) * Add some testing for index-add. * Fix the cpu implementation for index-add. --- candle-core/src/cpu_backend.rs | 19 +++++++------ candle-core/src/tensor.rs | 18 +++++++++---- candle-core/tests/tensor_tests.rs | 44 +++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 15 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 9a6320ec..8d38b158 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -815,7 +815,9 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> { }; let dim = self.dim; let max_idx = l1.dims()[dim]; - let stride = src_l.stride()[dim]; + let pre_dim = src_l.dims()[..dim].iter().product::<usize>(); + let src_dim_sz = src_l.dims()[dim]; + let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>(); if dim == 0 { for (src_idx, dst_idx) in self.ids.iter().enumerate() { let dst_idx = dst_idx.as_usize(); @@ -826,17 +828,15 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> { size: max_idx, })? } - let src_idx = src_idx * stride; - let dst_idx = dst_idx * stride; - let src = &src[src_idx..src_idx + stride]; - let dst = &mut dst[dst_idx..dst_idx + stride]; + let src_idx = src_idx * post_dim; + let dst_idx = dst_idx * post_dim; + let src = &src[src_idx..src_idx + post_dim]; + let dst = &mut dst[dst_idx..dst_idx + post_dim]; for (d, &s) in dst.iter_mut().zip(src.iter()) { *d += s } } } else { - let pre_dim = src_l.dims()[..dim].iter().product::<usize>(); - let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>(); for (src_idx, dst_idx) in self.ids.iter().enumerate() { let dst_idx = dst_idx.as_usize(); if dst_idx >= max_idx { @@ -847,9 +847,8 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> { })? 
} for pre_i in 0..pre_dim { - let pre_i = pre_i * stride; - let pre_src_i = (pre_i + src_idx) * post_dim; - let pre_dst_i = (pre_i + dst_idx) * post_dim; + let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim; + let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim; let src = &src[pre_src_i..pre_src_i + post_dim]; let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim]; for (d, &s) in dst.iter_mut().zip(src.iter()) { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 5257219d..b83d7b64 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1010,18 +1010,26 @@ impl Tensor { break; } } - // The number of element in indexes must match the dimension on which the add is - // performed on the source tensor (and the index values from `indexes` are taken from - // the target tensor self) - mismatch || source_dims[dim] != indexes.dims1()? + mismatch }; if mismatch { Err(Error::ShapeMismatchBinaryOp { - op: "index-add", + op: "index-add (self, source)", lhs: self.shape().clone(), rhs: source.shape().clone(), })? } + // The number of element in indexes must match the dimension on which the add is + // performed on the source tensor (and the index values from `indexes` are taken from + // the target tensor self) + let indexes_len = indexes.dims1()?; + if source_dims[dim] != indexes_len { + Err(Error::ShapeMismatchBinaryOp { + op: "index-add (ids, source))", + lhs: indexes.shape().clone(), + rhs: source.shape().clone(), + })? 
+ } let storage = self.storage().index_add( self.layout(), &indexes.storage(), diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 501c55ec..356e64d3 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -346,6 +346,49 @@ fn index_select(device: &Device) -> Result<()> { Ok(()) } +fn index_add(device: &Device) -> Result<()> { + let ids = Tensor::new(&[0u32, 1u32, 1u32], device)?; + let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?; + assert_eq!( + t.to_vec2::<f32>()?, + &[ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0] + ] + ); + let init = Tensor::ones((4, 2), DType::F32, device)?; + let hs = init.index_add(&ids, &t, 1)?; + assert_eq!( + hs.to_vec2::<f32>()?, + &[[1.0, 4.0], [4.0, 10.0], [7.0, 16.0], [10.0, 22.0]], + ); + let init = Tensor::zeros((4, 2), DType::F32, device)?; + let ids = Tensor::new(&[1u32, 0u32, 0u32], device)?; + let hs = init.index_add(&ids, &t, 1)?; + assert_eq!( + hs.to_vec2::<f32>()?, + &[[3.0, 0.0], [9.0, 3.0], [15.0, 6.0], [21.0, 9.0]], + ); + + let init = Tensor::zeros((6, 3), DType::F32, device)?; + let ids = Tensor::new(&[5u32, 0u32, 1u32, 0u32], device)?; + let hs = init.index_add(&ids, &t, 0)?; + assert_eq!( + hs.to_vec2::<f32>()?, + &[ + [12.0, 14.0, 16.0], + [6.0, 7.0, 8.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 2.0] + ] + ); + Ok(()) +} + fn gather(device: &Device) -> Result<()> { let ids = Tensor::new(&[[0u32], [2u32], [1u32], [0u32]], device)?; let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?; @@ -543,4 +586,5 @@ test_device!(cmp, cmp_cpu, cmp_gpu); test_device!(matmul, matmul_cpu, matmul_gpu); test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu); test_device!(index_select, index_select_cpu, index_select_gpu); +test_device!(index_add, index_add_cpu, index_add_gpu); test_device!(gather, gather_cpu, gather_gpu);