From a2a20aeeccdcbcbd00e48d6f7ac97b2435b2378c Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 2 Nov 2023 20:01:34 +0100
Subject: [PATCH] Add the swiglu activation from the chatglm PR. (#1246)

---
 candle-nn/src/activation.rs | 2 ++
 candle-nn/src/ops.rs        | 5 +++++
 2 files changed, 7 insertions(+)

diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 799e2ee2..77e709d2 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -14,6 +14,7 @@ pub enum Activation {
     Silu,
     Sigmoid,
     HardSigmoid,
+    Swiglu,
     Swish,
     HardSwish,
     Elu(f64),
@@ -32,6 +33,7 @@ impl super::Module for Activation {
             Self::Silu => crate::ops::silu(xs),
             Self::Sigmoid => crate::ops::sigmoid(xs),
             Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
+            Self::Swiglu => crate::ops::swiglu(xs),
             Self::Swish => xs * crate::ops::sigmoid(xs)?,
             Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
             &Self::Elu(alpha) => xs.elu(alpha),
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index a51ef2e3..a0269e59 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -39,6 +39,11 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
     xs / (xs.neg()?.exp()? + 1.0)?
 }
 
+pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
+    let xs = xs.chunk(2, candle::D::Minus1)?;
+    crate::ops::silu(&xs[0])? * &xs[1]
+}
+
 pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
     // TODO: Should we have a specialized op for this?
     (xs.neg()?.exp()? + 1.0)?.recip()
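
For context (not part of the patch): the new swiglu op splits its input in two along the last dimension and returns silu(a) * b. Below is a minimal usage sketch, assuming the candle-core crate is imported under the name `candle` (the same `package = "candle-core"` rename that candle-nn itself uses) and that `Activation` and `Module` are reachable from the candle_nn crate root, as the patch implies.

use candle::{Device, Result, Tensor};
use candle_nn::{Activation, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // swiglu halves the last dimension, so it must be even (8 -> 4 here).
    let xs = Tensor::randn(0f32, 1f32, (2, 8), &dev)?;

    // Free-function form added in candle-nn/src/ops.rs.
    let ys = candle_nn::ops::swiglu(&xs)?;
    assert_eq!(ys.dims(), &[2, 4]);

    // Enum form added in candle-nn/src/activation.rs.
    let ys2 = Activation::Swiglu.forward(&xs)?;
    assert_eq!(ys2.dims(), &[2, 4]);
    Ok(())
}

Because the op chunks the last dimension in two, a projection feeding it is typically made twice as wide as the intended hidden size, as in the SwiGLU MLP blocks this activation is meant for.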