Add some testing for index-add (#237)

* Add some testing for index-add.

* Fix the cpu implementation for index-add.
This commit is contained in:
Laurent Mazare
2023-07-25 08:38:33 +01:00
committed by GitHub
parent 74a6a769dd
commit 18cc73954a
3 changed files with 66 additions and 15 deletions

View File

@ -815,7 +815,9 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
}; };
let dim = self.dim; let dim = self.dim;
let max_idx = l1.dims()[dim]; let max_idx = l1.dims()[dim];
let stride = src_l.stride()[dim]; let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
let src_dim_sz = src_l.dims()[dim];
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
if dim == 0 { if dim == 0 {
for (src_idx, dst_idx) in self.ids.iter().enumerate() { for (src_idx, dst_idx) in self.ids.iter().enumerate() {
let dst_idx = dst_idx.as_usize(); let dst_idx = dst_idx.as_usize();
@ -826,17 +828,15 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
size: max_idx, size: max_idx,
})? })?
} }
let src_idx = src_idx * stride; let src_idx = src_idx * post_dim;
let dst_idx = dst_idx * stride; let dst_idx = dst_idx * post_dim;
let src = &src[src_idx..src_idx + stride]; let src = &src[src_idx..src_idx + post_dim];
let dst = &mut dst[dst_idx..dst_idx + stride]; let dst = &mut dst[dst_idx..dst_idx + post_dim];
for (d, &s) in dst.iter_mut().zip(src.iter()) { for (d, &s) in dst.iter_mut().zip(src.iter()) {
*d += s *d += s
} }
} }
} else { } else {
let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
for (src_idx, dst_idx) in self.ids.iter().enumerate() { for (src_idx, dst_idx) in self.ids.iter().enumerate() {
let dst_idx = dst_idx.as_usize(); let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx { if dst_idx >= max_idx {
@ -847,9 +847,8 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
})? })?
} }
for pre_i in 0..pre_dim { for pre_i in 0..pre_dim {
let pre_i = pre_i * stride; let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
let pre_src_i = (pre_i + src_idx) * post_dim; let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
let pre_dst_i = (pre_i + dst_idx) * post_dim;
let src = &src[pre_src_i..pre_src_i + post_dim]; let src = &src[pre_src_i..pre_src_i + post_dim];
let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim]; let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
for (d, &s) in dst.iter_mut().zip(src.iter()) { for (d, &s) in dst.iter_mut().zip(src.iter()) {

View File

@ -1010,18 +1010,26 @@ impl Tensor {
break; break;
} }
} }
// The number of element in indexes must match the dimension on which the add is mismatch
// performed on the source tensor (and the index values from `indexes` are taken from
// the target tensor self)
mismatch || source_dims[dim] != indexes.dims1()?
}; };
if mismatch { if mismatch {
Err(Error::ShapeMismatchBinaryOp { Err(Error::ShapeMismatchBinaryOp {
op: "index-add", op: "index-add (self, source)",
lhs: self.shape().clone(), lhs: self.shape().clone(),
rhs: source.shape().clone(), rhs: source.shape().clone(),
})? })?
} }
// The number of element in indexes must match the dimension on which the add is
// performed on the source tensor (and the index values from `indexes` are taken from
// the target tensor self)
let indexes_len = indexes.dims1()?;
if source_dims[dim] != indexes_len {
Err(Error::ShapeMismatchBinaryOp {
op: "index-add (ids, source))",
lhs: indexes.shape().clone(),
rhs: source.shape().clone(),
})?
}
let storage = self.storage().index_add( let storage = self.storage().index_add(
self.layout(), self.layout(),
&indexes.storage(), &indexes.storage(),

View File

@ -346,6 +346,49 @@ fn index_select(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
/// Exercises `Tensor::index_add` on both the column axis (dim 1) and the
/// row axis (dim 0), including repeated indices (which must accumulate) and
/// target rows that receive no contribution (which must stay at their
/// initial value).
fn index_add(device: &Device) -> Result<()> {
    // Source tensor: 0..12 laid out as a 4x3 matrix.
    let src = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
    assert_eq!(
        src.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0],
            [3.0, 4.0, 5.0],
            [6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0]
        ]
    );

    // Column-wise add (dim 1): column 0 of `src` goes to column 0 of the
    // target, columns 1 and 2 both accumulate into target column 1.
    // Target starts at all ones, so each entry is 1 + sum of contributions.
    let ids = Tensor::new(&[0u32, 1u32, 1u32], device)?;
    let base = Tensor::ones((4, 2), DType::F32, device)?;
    let out = base.index_add(&ids, &src, 1)?;
    assert_eq!(
        out.to_vec2::<f32>()?,
        &[[1.0, 4.0], [4.0, 10.0], [7.0, 16.0], [10.0, 22.0]],
    );

    // Same shape, different index pattern: source columns 1 and 2 fold into
    // target column 0, source column 0 lands in target column 1. Zero init
    // makes the expected values the raw sums.
    let ids = Tensor::new(&[1u32, 0u32, 0u32], device)?;
    let base = Tensor::zeros((4, 2), DType::F32, device)?;
    let out = base.index_add(&ids, &src, 1)?;
    assert_eq!(
        out.to_vec2::<f32>()?,
        &[[3.0, 0.0], [9.0, 3.0], [15.0, 6.0], [21.0, 9.0]],
    );

    // Row-wise add (dim 0) into a taller (6, 3) target: rows 1 and 3 of
    // `src` both accumulate into target row 0; untouched rows 2-4 remain
    // zero.
    let ids = Tensor::new(&[5u32, 0u32, 1u32, 0u32], device)?;
    let base = Tensor::zeros((6, 3), DType::F32, device)?;
    let out = base.index_add(&ids, &src, 0)?;
    assert_eq!(
        out.to_vec2::<f32>()?,
        &[
            [12.0, 14.0, 16.0],
            [6.0, 7.0, 8.0],
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 1.0, 2.0]
        ]
    );
    Ok(())
}
fn gather(device: &Device) -> Result<()> { fn gather(device: &Device) -> Result<()> {
let ids = Tensor::new(&[[0u32], [2u32], [1u32], [0u32]], device)?; let ids = Tensor::new(&[[0u32], [2u32], [1u32], [0u32]], device)?;
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?; let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
@ -543,4 +586,5 @@ test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu); test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu); test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu); test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu); test_device!(gather, gather_cpu, gather_gpu);