diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index 59e3a847..b92fd931 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -178,18 +178,22 @@ impl CpuStorage { dst_dims[sum_dim] = 1; } let dst_shape = Shape::from(dst_dims); - let sum_dims_and_stride: Vec<_> = src_dims + let mut sum_dims = sum_dims.to_vec(); + // Sort the sum_dims as they have to be processed from left to right when converting the + // indexes. + sum_dims.sort(); + let sum_dims_and_stride: Vec<_> = sum_dims .iter() - .enumerate() - .map(|(i, d)| (d, src_dims[i + 1..].iter().product::())) + .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::())) .collect(); let to_dst_index = |unstr_index: usize| { - // TODO: Optimize, the following does lots of slow division and modulos. + // TODO: Optimize, the following does lots of slow division. let mut dst_index = unstr_index; // Set the sum_dims indexes to 0. for &(dim, stride) in sum_dims_and_stride.iter() { - let index = dst_index / stride % dim; - dst_index -= index * stride; + // The compiler is able to optimize the following in a single divmod op. + let (pre, post) = (dst_index / stride, dst_index % stride); + dst_index = (pre / dim) * stride + post; } dst_index }; diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index 1dde769a..4b1f3cc4 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -108,6 +108,25 @@ fn softmax() -> Result<()> { Ok(()) } +#[test] +fn sum() -> Result<()> { + let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; + let tensor = Tensor::new(data, &Device::Cpu)?; + assert_eq!( + tensor.sum(&[2])?.to_vec3::()?, + &[[[8], [15]], [[10], [18]]] + ); + assert_eq!( + tensor.t()?.sum(&[1])?.t()?.to_vec3::()?, + &[[[8], [15]], [[10], [18]]] + ); + assert_eq!( + tensor.sum(&[2, 1])?.to_vec3::()?, + &[[[8 + 15]], [[10 + 18]]] + ); + Ok(()) +} + #[test] fn narrow() -> Result<()> { let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];