feat: add silu activation function (#1706)

* feat: add silu activation function

* use silu/arg in grad

* update candle-nn

* use node
This commit is contained in:
OlivierDehaene
2024-02-14 10:27:22 +01:00
committed by GitHub
parent 14010a8498
commit b60064780d
14 changed files with 206 additions and 5 deletions

View File

@ -35,13 +35,12 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
}
pub fn silu(xs: &Tensor) -> Result<Tensor> {
// TODO: Should we have a specialized op for this?
xs / (xs.neg()?.exp()? + 1.0)?
xs.silu()
}
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
let xs = xs.chunk(2, candle::D::Minus1)?;
crate::ops::silu(&xs[0])? * &xs[1]
&xs[0].silu()? * &xs[1]
}
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {