Fix the silu gradient issue on 0. (#2083)
@@ -624,7 +624,7 @@ impl Tensor {
             Op::Unary(arg, UnaryOp::Silu) => {
                 let sum_grad = grads.or_insert(arg)?;
                 // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
-                let sigmoid_arg = (*node / arg)?;
+                let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
                 let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
                 *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
             }
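For context: silu(x) = x * sigmoid(x), so by the product rule silu'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)) = sigmoid(x) * (1 + x * (1 - sigmoid(x))), matching the comment in the hunk. The removed line recovered sigmoid(x) as node / arg, i.e. silu(x) / x, which is 0/0 at x = 0 and poisons the gradient with NaN; the replacement computes sigmoid(x) = 1 / (1 + exp(-x)) directly. Below is a minimal sketch reproducing the difference with candle's public tensor ops (the sample input values are illustrative and not part of the commit):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::new(&[0f32, 1.0, -2.0], &dev)?;

    // silu(x) = x * sigmoid(x); this plays the role of `node` in the backprop code.
    let sigmoid = (x.neg()?.exp()? + 1.)?.recip()?;
    let silu = (&x * &sigmoid)?;

    // Old recovery of sigmoid: silu(x) / x -> 0/0 = NaN at x = 0.
    let sigmoid_old = (&silu / &x)?;
    // Fixed computation: 1 / (1 + exp(-x)), finite for every x.
    let sigmoid_new = (x.neg()?.exp()? + 1.)?.recip()?;

    println!("old: {:?}", sigmoid_old.to_vec1::<f32>()?); // [NaN, 0.7310586, 0.11920292]
    println!("new: {:?}", sigmoid_new.to_vec1::<f32>()?); // [0.5, 0.7310586, 0.11920292]
    Ok(())
}
```

The direct formulation costs an extra exp in the backward pass, but it removes the division singularity at zero.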