mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Use the same default as pytorch for sum. (#164)
This commit is contained in:
@ -195,11 +195,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
let mut arg_grad = grad.sum(sum_dims.as_slice())?;
|
||||
// sum_dims has increasing values.
|
||||
for &dim in sum_dims.iter().rev() {
|
||||
arg_grad = arg_grad.squeeze(dim)?
|
||||
}
|
||||
let arg_grad = grad.sum(sum_dims.as_slice())?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.broadcast_add(&arg_grad)?
|
||||
}
|
||||
|
Reference in New Issue
Block a user