Fix the silu gradient issue on 0. (#2083)

Author: Laurent Mazare
Date: 2024-04-18 14:31:41 +02:00
Committed by: GitHub
Parent: 8de0ce6cba
Commit: 1690ab45d2


@@ -624,7 +624,7 @@ impl Tensor {
             Op::Unary(arg, UnaryOp::Silu) => {
                 let sum_grad = grads.or_insert(arg)?;
                 // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
-                let sigmoid_arg = (*node / arg)?;
+                let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
                 let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
                 *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
             }
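
The removed line recovered sigmoid(x) by dividing the node's value, silu(x) = x * sigmoid(x), by its input x; at x = 0 that is 0/0, so the gradient came out as NaN. The replacement computes sigmoid(x) = 1 / (1 + e^(-x)) directly, which is defined everywhere. Below is a minimal scalar sketch of the two approaches in plain Rust; it uses made-up helper names for illustration and is not the candle tensor API.

    // Scalar sketch of the SiLU gradient formula used in the diff.
    // silu(x) = x * sigmoid(x); d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))

    fn sigmoid(x: f64) -> f64 {
        ((-x).exp() + 1.0).recip()
    }

    fn silu(x: f64) -> f64 {
        x * sigmoid(x)
    }

    // Old approach: recover sigmoid(x) as silu(x) / x (i.e. *node / arg in the diff).
    // At x = 0 this is 0.0 / 0.0 = NaN, which poisons the whole gradient.
    fn silu_grad_old(x: f64) -> f64 {
        let sigmoid_x = silu(x) / x;
        sigmoid_x * (1.0 + x * (1.0 - sigmoid_x))
    }

    // New approach: compute sigmoid(x) directly; well defined for every x.
    fn silu_grad_new(x: f64) -> f64 {
        let sigmoid_x = sigmoid(x);
        sigmoid_x * (1.0 + x * (1.0 - sigmoid_x))
    }

    fn main() {
        println!("old grad at 0: {}", silu_grad_old(0.0)); // NaN
        println!("new grad at 0: {}", silu_grad_new(0.0)); // 0.5
    }

Running the sketch prints NaN for the old formula at x = 0 and 0.5 for the fixed one, matching sigmoid(0) * (1 + 0 * (1 - sigmoid(0))) = 0.5.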