Fix the minimum/maximum gradient computations. (#534)

This commit is contained in:
Laurent Mazare
2023-08-21 08:28:41 +01:00
committed by GitHub
parent 912561614f
commit d70cffdab6
2 changed files with 35 additions and 3 deletions

View File

@ -164,13 +164,18 @@ impl Tensor {
}
Op::Binary(lhs, rhs, BinaryOp::Minimum)
| Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
let lhs_grad = node.eq(lhs)?.to_dtype(grad.dtype())?.mul(&grad)?;
let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;
// If both masks are 1 one the same point, we want to scale the
// gradient by 0.5 rather than 1.
let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = node.eq(rhs)?.to_dtype(grad.dtype())?.mul(&grad)?;
let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::WhereCond(pred, t, f) => {
let zeros = grad.zeros_like()?;