Use the same default as pytorch for sum. (#164)

This commit is contained in:
Laurent Mazare
2023-07-13 21:32:32 +01:00
committed by GitHub
parent 57be3638d8
commit 2bfa791336
13 changed files with 123 additions and 56 deletions

View File

@@ -27,18 +27,18 @@ fn main() -> Result<()> {
// Build the same cos/sin tensor twice: once on the CPU and once on the
// target device, so the two backends' results can be compared side by side.
let xys_cpu = cos_sin(n, &Device::Cpu)?;
let xys = cos_sin(n, &device)?;
println!("{xys_cpu:?} {xys:?}");
// Sum along axis 1 with the PyTorch-style default (this commit's change):
// the reduced dimension is squeezed out of the result shape.
let sum_cpu = xys_cpu.sum(&[1])?;
println!("{sum_cpu}");
let sum = xys.sum(&[1])?;
println!("{sum}");
// Same reduction, but keeping the reduced dimension as size 1.
let sum_keepdim_cpu = xys_cpu.sum_keepdim(&[1])?;
println!("{sum_keepdim_cpu}");
let sum_keepdim = xys.sum_keepdim(&[1])?;
println!("{sum_keepdim}");
// Micro-benchmark: time `n_iters` rounds of a full reduce-to-scalar.
let start = std::time::Instant::now();
let n_iters = 100;
// Accumulator for the scalar results; it is inspected after the loop
// (see the `if v > 0.` below), which keeps the per-iteration reductions
// from being considered dead work.
let mut v = 0f32;
for _i in 0..n_iters {
    // Two-step reduction (axis 1, then axis 0) down to a single-element
    // tensor; `reshape(&[])` turns it into a 0-d tensor for `to_scalar`.
    let sum = xys.sum(&[1])?;
    let sum = sum.sum(&[0])?;
    let sum: f32 = sum.reshape(&[])?.to_scalar()?;
    v += sum;
    // Same reduction via the keepdim variant; here `reshape(&[])` also
    // squeezes away the retained size-1 dims before scalar extraction.
    let sum_keepdim = xys.sum_keepdim(&[1])?;
    let sum_keepdim = sum_keepdim.sum_keepdim(&[0])?;
    let sum_keepdim: f32 = sum_keepdim.reshape(&[])?.to_scalar()?;
    v += sum_keepdim;
}
let elapsed = start.elapsed();
if v > 0. {