diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index 4b1f3cc4..e61d58ed 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -116,6 +116,10 @@ fn sum() -> Result<()> { tensor.sum(&[2])?.to_vec3::()?, &[[[8], [15]], [[10], [18]]] ); + assert_eq!( + tensor.sum(&[0])?.to_vec3::()?, + &[[[5, 2, 11], [9, 7, 17]]], + ); assert_eq!( tensor.t()?.sum(&[1])?.t()?.to_vec3::()?, &[[[8], [15]], [[10], [18]]]