diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 799e2ee2..77e709d2 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -14,6 +14,7 @@ pub enum Activation {
     Silu,
     Sigmoid,
     HardSigmoid,
+    Swiglu,
     Swish,
     HardSwish,
     Elu(f64),
@@ -32,6 +33,7 @@ impl super::Module for Activation {
             Self::Silu => crate::ops::silu(xs),
             Self::Sigmoid => crate::ops::sigmoid(xs),
             Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
+            Self::Swiglu => crate::ops::swiglu(xs),
             Self::Swish => xs * crate::ops::sigmoid(xs)?,
             Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
             &Self::Elu(alpha) => xs.elu(alpha),
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index a51ef2e3..a0269e59 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -39,6 +39,11 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
     xs / (xs.neg()?.exp()? + 1.0)?
 }
 
+pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
+    let xs = xs.chunk(2, candle::D::Minus1)?;
+    crate::ops::silu(&xs[0])? * &xs[1]
+}
+
 pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
     // TODO: Should we have a specialized op for this?
     (xs.neg()?.exp()? + 1.0)?.recip()
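
For context, a minimal usage sketch of the new op. The input shape, device, and `main` scaffolding below are illustrative assumptions, not part of the patch; the only real requirement is that the last dimension is even so `chunk(2, D::Minus1)` can split it in half.

```rust
use candle::{Device, Result, Tensor};
use candle_nn::{Activation, Module};

fn main() -> Result<()> {
    // Hypothetical input: batch of 2, feature dim 8 (must be even for the split).
    let xs = Tensor::randn(0f32, 1f32, (2, 8), &Device::Cpu)?;

    // Free-function form: swiglu(x) = silu(a) * b, where [a, b] = chunk(x, 2) on the last dim.
    let ys = candle_nn::ops::swiglu(&xs)?;
    assert_eq!(ys.dims(), &[2, 4]); // last dimension is halved

    // Same thing through the new Activation variant and the Module trait.
    let ys2 = Activation::Swiglu.forward(&xs)?;
    assert_eq!(ys2.dims(), &[2, 4]);
    Ok(())
}
```

This mirrors the usual gated-MLP pattern where a single projection produces both the gate and the value halves; `swiglu` fuses the split and the silu gating into one call, so the output has half the input's last dimension.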