Use the same default as pytorch for sum. (#164)

Laurent Mazare
2023-07-13 21:32:32 +01:00
committed by GitHub
parent 57be3638d8
commit 2bfa791336
13 changed files with 123 additions and 56 deletions
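
Before this change, candle's sum kept the dimensions it reduced over (as size-1 dimensions), whereas PyTorch's Tensor.sum drops them by default (keepdim=False). This commit renames the old behaviour to sum_keepdim and lets sum squeeze the summed dimensions, matching PyTorch. A minimal sketch of the new semantics (not part of this diff; it assumes the crate is imported as candle and that sum takes a slice of dimension indices, as in the tests below):

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // A 2x2 tensor: [[1, 2], [3, 4]].
        let t = Tensor::new(&[[1u32, 2], [3, 4]], &Device::Cpu)?;
        // sum now drops the reduced dimension, matching PyTorch's default
        // keepdim=False: summing over dim 0 yields shape (2,).
        assert_eq!(t.sum(&[0])?.to_vec1::<u32>()?, &[4, 6]);
        // sum_keepdim retains the reduced dimension as size 1: shape (1, 2).
        assert_eq!(t.sum_keepdim(&[0])?.to_vec2::<u32>()?, &[[4, 6]]);
        Ok(())
    }

The hunk below shows the corresponding test update: every call that relied on the old keep-dims behaviour now goes through sum_keepdim.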

@@ -108,56 +108,99 @@ fn sum(device: &Device) -> Result<()> {
     let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
     let tensor = Tensor::new(data, device)?;
     assert_eq!(
-        tensor.sum(&[2])?.to_vec3::<u32>()?,
+        tensor.sum_keepdim(&[2])?.to_vec3::<u32>()?,
         &[[[8], [15]], [[10], [18]]]
     );
     assert_eq!(
-        tensor.sum(&[0])?.to_vec3::<u32>()?,
+        tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
         &[[[5, 2, 11], [9, 7, 17]]],
     );
-    assert_eq!(tensor.sum(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
+    assert_eq!(tensor.sum_keepdim(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
     assert_eq!(
-        tensor.t()?.sum(&[1])?.t()?.to_vec3::<u32>()?,
+        tensor.t()?.sum_keepdim(&[1])?.t()?.to_vec3::<u32>()?,
         &[[[8], [15]], [[10], [18]]]
     );
     assert_eq!(
-        tensor.sum(&[2, 1])?.to_vec3::<u32>()?,
+        tensor.sum_keepdim(&[2, 1])?.to_vec3::<u32>()?,
         &[[[8 + 15]], [[10 + 18]]]
     );
     let data: Vec<u32> = (0..4000u32).collect();
     let tensor = Tensor::new(data.as_slice(), device)?;
-    assert_eq!(tensor.sum(&[0])?.to_vec1::<u32>()?, &[7998000]);
+    assert_eq!(tensor.sum_keepdim(&[0])?.to_vec1::<u32>()?, &[7998000]);
     let tensor = tensor.reshape((2000, 2))?;
-    assert_eq!(tensor.sum(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[0])?.sum(&[1])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[1])?.sum(&[0])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[0])?.to_vec2::<u32>()?, &[[3998000, 4000000]]);
+    assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
+    assert_eq!(
+        tensor
+            .sum_keepdim(&[0])?
+            .sum_keepdim(&[1])?
+            .to_vec2::<u32>()?,
+        &[[7998000]]
+    );
+    assert_eq!(
+        tensor
+            .sum_keepdim(&[1])?
+            .sum_keepdim(&[0])?
+            .to_vec2::<u32>()?,
+        &[[7998000]]
+    );
+    assert_eq!(
+        tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
+        &[[3998000, 4000000]]
+    );
     // Make the tensor non contiguous.
     let tensor = tensor.t()?.contiguous()?.t()?;
-    assert_eq!(tensor.sum(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[0])?.sum(&[1])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[1])?.sum(&[0])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[0])?.to_vec2::<u32>()?, &[[3998000, 4000000]]);
+    assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
+    assert_eq!(
+        tensor
+            .sum_keepdim(&[0])?
+            .sum_keepdim(&[1])?
+            .to_vec2::<u32>()?,
+        &[[7998000]]
+    );
+    assert_eq!(
+        tensor
+            .sum_keepdim(&[1])?
+            .sum_keepdim(&[0])?
+            .to_vec2::<u32>()?,
+        &[[7998000]]
+    );
+    assert_eq!(
+        tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
+        &[[3998000, 4000000]]
+    );
     let t1 = tensor.reshape((200, 5, 4))?;
     let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
     for tensor in [t1, t2] {
-        assert_eq!(tensor.sum(&[0, 1, 2])?.to_vec3::<u32>()?, &[[[7998000]]]);
         assert_eq!(
-            tensor.sum(&[0])?.sum(&[2])?.sum(&[1])?.to_vec3::<u32>()?,
+            tensor.sum_keepdim(&[0, 1, 2])?.to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
-            tensor.sum(&[0])?.sum(&[1, 2])?.to_vec3::<u32>()?,
+            tensor
+                .sum_keepdim(&[0])?
+                .sum_keepdim(&[2])?
+                .sum_keepdim(&[1])?
+                .to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
-            tensor.sum(&[1])?.sum(&[0, 2])?.to_vec3::<u32>()?,
+            tensor
+                .sum_keepdim(&[0])?
+                .sum_keepdim(&[1, 2])?
+                .to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
-            tensor.sum(&[0])?.to_vec3::<u32>()?,
+            tensor
+                .sum_keepdim(&[1])?
+                .sum_keepdim(&[0, 2])?
+                .to_vec3::<u32>()?,
+            &[[[7998000]]]
+        );
+        assert_eq!(
+            tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
             &[[
                 [398000, 398200, 398400, 398600],
                 [398800, 399000, 399200, 399400],
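
For the record, the expected values above are straightforward arithmetic: 0 + 1 + ... + 3999 = 3999 * 4000 / 2 = 7998000; after the (2000, 2) reshape the two columns hold the even and odd entries, which sum to 3998000 and 4000000; and after the (200, 5, 4) reshape element (k, r, c) equals 20k + 4r + c, so the keep-dims sum over the first dimension is 398000 + 800r + 200c, which is exactly the grid the last assertion spells out.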