mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
feat: add silu activation function (#1706)
* feat: add silu activation function * use silu/arg in grad * update candle-nn * use node
This commit is contained in:
@ -589,6 +589,13 @@ impl Tensor {
|
||||
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
||||
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Silu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
|
||||
let sigmoid_arg = (*node / arg)?;
|
||||
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
|
||||
}
|
||||
Op::Elu(arg, alpha) => {
|
||||
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
Reference in New Issue
Block a user