Use the same default as pytorch for sum. (#164)

This commit is contained in:
Laurent Mazare
2023-07-13 21:32:32 +01:00
committed by GitHub
parent 57be3638d8
commit 2bfa791336
13 changed files with 123 additions and 56 deletions

View File

@ -195,11 +195,7 @@ impl Tensor {
}
}
let mut arg_grad = grad.sum(sum_dims.as_slice())?;
// sum_dims has increasing values.
for &dim in sum_dims.iter().rev() {
arg_grad = arg_grad.squeeze(dim)?
}
let arg_grad = grad.sum(sum_dims.as_slice())?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.broadcast_add(&arg_grad)?
}