diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index a9d5a6a6..d6293aa4 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -634,7 +634,8 @@ impl Tensor { let zeros = arg.zeros_like()?; let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?; let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?; - let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?; + // node == alpha * (e^x - 1) for x <= 0, reuse it + let negative_exp_mask = (negative_mask * (*node + *alpha))?; let combined_mask = (positive_mask + negative_exp_mask)?; *sum_grad = sum_grad.add(&(grad * combined_mask)?)? }