Fix Elu gradient NaN on large input (#2328)

* Fix Elu gradient NaN on large input

* Reuse previously computed exp in Elu
Author: Alexey Gerasev
Date: 2024-07-16 19:41:16 +07:00
Committed by: GitHub
Parent: 30cdd769f9
Commit: 6a4741bbf9

@@ -634,7 +634,8 @@ impl Tensor {
                 let zeros = arg.zeros_like()?;
                 let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
                 let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
-                let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
+                // node == alpha * (e^x - 1) for x <= 0, reuse it
+                let negative_exp_mask = (negative_mask * (*node + *alpha))?;
                 let combined_mask = (positive_mask + negative_exp_mask)?;
                 *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
             }
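
For context, a minimal scalar sketch of why the old gradient path produced NaN and why reusing the forward output avoids it. This is plain Rust f32 arithmetic with a hypothetical standalone main, not candle tensors: for a large positive x, exp(x) overflows to infinity, and multiplying by the zero-valued negative mask gives 0 * inf = NaN, which leaks into the gradient even though the branch is masked out; the forward output node is simply x on that branch and stays finite.

// Hypothetical standalone sketch (f32 scalars, not candle tensors), assuming alpha = 1.0.
fn main() {
    let x: f32 = 200.0; // large positive input
    let alpha: f32 = 1.0;
    let negative_mask: f32 = if x <= 0.0 { 1.0 } else { 0.0 };

    // Old gradient term: exp(200) overflows f32 to +inf, and 0 * inf == NaN.
    let old_term = negative_mask * x.exp() * alpha;
    assert!(old_term.is_nan());

    // New gradient term: reuse the forward output `node` (the Elu value),
    // which is x for x > 0 and alpha * (e^x - 1) for x <= 0, so it never overflows here.
    let node = if x > 0.0 { x } else { alpha * (x.exp() - 1.0) };
    let new_term = negative_mask * (node + alpha);
    assert_eq!(new_term, 0.0);
}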