mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Fix Elu gradient NaN on large input (#2328)
* Fix Elu gradient NaN on large input
* Reuse previously computed exp in Elu
This commit is contained in:
@@ -634,7 +634,8 @@ impl Tensor {
|
||||
let zeros = arg.zeros_like()?;
|
||||
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
|
||||
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
|
||||
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
|
||||
// node == alpha * (e^x - 1) for x <= 0, so node + alpha == alpha * e^x; reuse it
|
||||
let negative_exp_mask = (negative_mask * (*node + *alpha))?;
|
||||
let combined_mask = (positive_mask + negative_exp_mask)?;
|
||||
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
|
||||
}
|
||||
|
Reference in New Issue
Block a user