mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Simplify the parameters used by sum and sum_keepdim. (#165)
This commit is contained in:
@ -108,65 +108,53 @@ fn sum(device: &Device) -> Result<()> {
|
||||
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
assert_eq!(
|
||||
tensor.sum_keepdim(&[2])?.to_vec3::<u32>()?,
|
||||
tensor.sum_keepdim(2)?.to_vec3::<u32>()?,
|
||||
&[[[8], [15]], [[10], [18]]]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
|
||||
tensor.sum_keepdim(0)?.to_vec3::<u32>()?,
|
||||
&[[[5, 2, 11], [9, 7, 17]]],
|
||||
);
|
||||
assert_eq!(tensor.sum_keepdim(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
|
||||
assert_eq!(tensor.sum_keepdim((0, 2, 1))?.to_vec3::<u32>()?, &[[[51]]],);
|
||||
assert_eq!(
|
||||
tensor.t()?.sum_keepdim(&[1])?.t()?.to_vec3::<u32>()?,
|
||||
tensor.t()?.sum_keepdim(1)?.t()?.to_vec3::<u32>()?,
|
||||
&[[[8], [15]], [[10], [18]]]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor.sum_keepdim(&[2, 1])?.to_vec3::<u32>()?,
|
||||
tensor.sum_keepdim((2, 1))?.to_vec3::<u32>()?,
|
||||
&[[[8 + 15]], [[10 + 18]]]
|
||||
);
|
||||
let data: Vec<u32> = (0..4000u32).collect();
|
||||
let tensor = Tensor::new(data.as_slice(), device)?;
|
||||
assert_eq!(tensor.sum_keepdim(&[0])?.to_vec1::<u32>()?, &[7998000]);
|
||||
assert_eq!(tensor.sum_keepdim(0)?.to_vec1::<u32>()?, &[7998000]);
|
||||
let tensor = tensor.reshape((2000, 2))?;
|
||||
assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
|
||||
assert_eq!(tensor.sum_keepdim((0, 1))?.to_vec2::<u32>()?, &[[7998000]]);
|
||||
assert_eq!(
|
||||
tensor
|
||||
.sum_keepdim(&[0])?
|
||||
.sum_keepdim(&[1])?
|
||||
.to_vec2::<u32>()?,
|
||||
tensor.sum_keepdim(0)?.sum_keepdim(1)?.to_vec2::<u32>()?,
|
||||
&[[7998000]]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor
|
||||
.sum_keepdim(&[1])?
|
||||
.sum_keepdim(&[0])?
|
||||
.to_vec2::<u32>()?,
|
||||
tensor.sum_keepdim(1)?.sum_keepdim(0)?.to_vec2::<u32>()?,
|
||||
&[[7998000]]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
|
||||
tensor.sum_keepdim(0)?.to_vec2::<u32>()?,
|
||||
&[[3998000, 4000000]]
|
||||
);
|
||||
|
||||
// Make the tensor non contiguous.
|
||||
let tensor = tensor.t()?.contiguous()?.t()?;
|
||||
assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
|
||||
assert_eq!(tensor.sum_keepdim((0, 1))?.to_vec2::<u32>()?, &[[7998000]]);
|
||||
assert_eq!(
|
||||
tensor
|
||||
.sum_keepdim(&[0])?
|
||||
.sum_keepdim(&[1])?
|
||||
.to_vec2::<u32>()?,
|
||||
tensor.sum_keepdim(0)?.sum_keepdim(1)?.to_vec2::<u32>()?,
|
||||
&[[7998000]]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor
|
||||
.sum_keepdim(&[1])?
|
||||
.sum_keepdim(&[0])?
|
||||
.to_vec2::<u32>()?,
|
||||
tensor.sum_keepdim(1)?.sum_keepdim(0)?.to_vec2::<u32>()?,
|
||||
&[[7998000]]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
|
||||
tensor.sum_keepdim(0)?.to_vec2::<u32>()?,
|
||||
&[[3998000, 4000000]]
|
||||
);
|
||||
|
||||
@ -174,33 +162,33 @@ fn sum(device: &Device) -> Result<()> {
|
||||
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
|
||||
for tensor in [t1, t2] {
|
||||
assert_eq!(
|
||||
tensor.sum_keepdim(&[0, 1, 2])?.to_vec3::<u32>()?,
|
||||
tensor.sum_keepdim((0, 1, 2))?.to_vec3::<u32>()?,
|
||||
&[[[7998000]]]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor
|
||||
.sum_keepdim(&[0])?
|
||||
.sum_keepdim(&[2])?
|
||||
.sum_keepdim(&[1])?
|
||||
.sum_keepdim(0)?
|
||||
.sum_keepdim(2)?
|
||||
.sum_keepdim(1)?
|
||||
.to_vec3::<u32>()?,
|
||||
&[[[7998000]]]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor
|
||||
.sum_keepdim(&[0])?
|
||||
.sum_keepdim(&[1, 2])?
|
||||
.sum_keepdim(0)?
|
||||
.sum_keepdim((1, 2))?
|
||||
.to_vec3::<u32>()?,
|
||||
&[[[7998000]]]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor
|
||||
.sum_keepdim(&[1])?
|
||||
.sum_keepdim(&[0, 2])?
|
||||
.sum_keepdim(1)?
|
||||
.sum_keepdim((0, 2))?
|
||||
.to_vec3::<u32>()?,
|
||||
&[[[7998000]]]
|
||||
);
|
||||
assert_eq!(
|
||||
tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
|
||||
tensor.sum_keepdim(0)?.to_vec3::<u32>()?,
|
||||
&[[
|
||||
[398000, 398200, 398400, 398600],
|
||||
[398800, 399000, 399200, 399400],
|
||||
|
Reference in New Issue
Block a user