diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 78ca4b05..7e4467d1 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -124,6 +124,49 @@ fn sum(device: &Device) -> Result<()> { tensor.sum(&[2, 1])?.to_vec3::()?, &[[[8 + 15]], [[10 + 18]]] ); + let data: Vec = (0..4000u32).collect(); + let tensor = Tensor::new(data.as_slice(), device)?; + assert_eq!(tensor.sum(&[0])?.to_vec1::()?, &[7998000]); + let tensor = tensor.reshape((2000, 2))?; + assert_eq!(tensor.sum(&[0, 1])?.to_vec2::()?, &[[7998000]]); + assert_eq!(tensor.sum(&[0])?.sum(&[1])?.to_vec2::()?, &[[7998000]]); + assert_eq!(tensor.sum(&[1])?.sum(&[0])?.to_vec2::()?, &[[7998000]]); + assert_eq!(tensor.sum(&[0])?.to_vec2::()?, &[[3998000, 4000000]]); + + // Make the tensor non contiguous. + let tensor = tensor.t()?.contiguous()?.t()?; + assert_eq!(tensor.sum(&[0, 1])?.to_vec2::()?, &[[7998000]]); + assert_eq!(tensor.sum(&[0])?.sum(&[1])?.to_vec2::()?, &[[7998000]]); + assert_eq!(tensor.sum(&[1])?.sum(&[0])?.to_vec2::()?, &[[7998000]]); + assert_eq!(tensor.sum(&[0])?.to_vec2::()?, &[[3998000, 4000000]]); + + let t1 = tensor.reshape((200, 5, 4))?; + let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?; + for tensor in [t1, t2] { + assert_eq!(tensor.sum(&[0, 1, 2])?.to_vec3::()?, &[[[7998000]]]); + assert_eq!( + tensor.sum(&[0])?.sum(&[2])?.sum(&[1])?.to_vec3::()?, + &[[[7998000]]] + ); + assert_eq!( + tensor.sum(&[0])?.sum(&[1, 2])?.to_vec3::()?, + &[[[7998000]]] + ); + assert_eq!( + tensor.sum(&[1])?.sum(&[0, 2])?.to_vec3::()?, + &[[[7998000]]] + ); + assert_eq!( + tensor.sum(&[0])?.to_vec3::()?, + &[[ + [398000, 398200, 398400, 398600], + [398800, 399000, 399200, 399400], + [399600, 399800, 400000, 400200], + [400400, 400600, 400800, 401000], + [401200, 401400, 401600, 401800] + ]] + ); + } Ok(()) }