From c0a559d427c04c0484c14f6052b2ea268af10c9d Mon Sep 17 00:00:00 2001
From: MilkFather <31627231+MilkFather@users.noreply.github.com>
Date: Sun, 4 Aug 2024 17:24:17 +0800
Subject: [PATCH] optimize gradient for silu a bit (#2393)

---
 candle-core/src/backprop.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 3ea03b0b..a5566774 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -623,9 +623,9 @@ impl Tensor {
                     }
                     Op::Unary(arg, UnaryOp::Silu) => {
                         let sum_grad = grads.or_insert(arg)?;
-                        // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+                        // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
                         let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
-                        let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
+                        let silu_grad = &sigmoid_arg * (1. - *node) + *node;
                         *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
                     }
                     Op::Elu(arg, alpha) => {
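
The equality added to the comment follows from node = silu(x) = x * sigmoid(x): expanding sigmoid(x) * (1 + x * (1 - sigmoid(x))) gives sigmoid(x) + x*sigmoid(x) - x*sigmoid(x)^2 = sigmoid(x) + node - node*sigmoid(x) = sigmoid(x) * (1 - node) + node, so the backward pass can reuse the forward output instead of recomputing from arg. The snippet below is a small standalone sketch (plain f64, no candle types; the sigmoid helper and the sample values are illustrative, not part of the patch) checking that the pre- and post-patch forms of the derivative agree:

// Sketch: verify that the two forms of the SiLU derivative used in the patch
// are numerically identical for a few sample inputs.
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn main() {
    for &x in &[-3.0_f64, -0.5, 0.0, 0.75, 2.0] {
        let s = sigmoid(x);
        let node = x * s; // forward output: silu(x) = x * sigmoid(x)
        let old = s * (1.0 + x * (1.0 - s)); // form used before the patch
        let new = s * (1.0 - node) + node; // form used after the patch
        assert!((old - new).abs() < 1e-12);
        println!("x = {x:+.2}  d/dx silu = {new:.6}");
    }
}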