mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
feat: add silu activation function (#1706)
* feat: add silu activation function * use silu/arg in grad * update candle-nn * use node
This commit is contained in:
@ -380,6 +380,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_exp_inplace(y: &mut [f32]) {
|
||||
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_exp_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
@ -402,6 +412,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vs_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vd_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! binary_op {
|
||||
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
|
||||
#[inline]
|
||||
|
@ -589,6 +589,13 @@ impl Tensor {
|
||||
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
||||
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Silu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
|
||||
let sigmoid_arg = (*node / arg)?;
|
||||
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
|
||||
}
|
||||
Op::Elu(arg, alpha) => {
|
||||
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
@ -679,6 +679,7 @@ impl BackendStorage for MetalStorage {
|
||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||
("usilu", DType::F32) => contiguous::silu::FLOAT,
|
||||
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||
@ -696,6 +697,7 @@ impl BackendStorage for MetalStorage {
|
||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||
("usilu", DType::F16) => contiguous::silu::HALF,
|
||||
("uabs", DType::F16) => contiguous::abs::HALF,
|
||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||
@ -730,6 +732,7 @@ impl BackendStorage for MetalStorage {
|
||||
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => strided::erf::FLOAT,
|
||||
("usilu", DType::F32) => strided::silu::FLOAT,
|
||||
("uabs", DType::F32) => strided::abs::FLOAT,
|
||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||
@ -745,6 +748,7 @@ impl BackendStorage for MetalStorage {
|
||||
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => strided::erf::HALF,
|
||||
("usilu", DType::F16) => strided::silu::HALF,
|
||||
("uabs", DType::F16) => strided::abs::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
|
@ -333,6 +333,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_exp_inplace(y: &mut [f32]) {
|
||||
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_exp_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
@ -355,6 +365,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vs_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vd_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! binary_op {
|
||||
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
|
||||
#[inline]
|
||||
|
@ -61,6 +61,7 @@ pub enum UnaryOp {
|
||||
GeluErf,
|
||||
Erf,
|
||||
Relu,
|
||||
Silu,
|
||||
Tanh,
|
||||
Floor,
|
||||
Ceil,
|
||||
@ -390,6 +391,7 @@ pub(crate) struct Gelu;
|
||||
pub(crate) struct GeluErf;
|
||||
pub(crate) struct Erf;
|
||||
pub(crate) struct Relu;
|
||||
pub(crate) struct Silu;
|
||||
pub(crate) struct Tanh;
|
||||
pub(crate) struct Floor;
|
||||
pub(crate) struct Ceil;
|
||||
@ -724,6 +726,77 @@ impl UnaryOpT for Erf {
|
||||
}
|
||||
}
|
||||
|
||||
/// Silu operation
|
||||
impl UnaryOpT for Silu {
|
||||
const NAME: &'static str = "silu";
|
||||
const V: Self = Silu;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v / (bf16::ONE + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v / (f16::ONE + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v / (1.0 + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v / (1.0 + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(_: i64) -> i64 {
|
||||
0
|
||||
}
|
||||
const KERNEL: &'static str = "usilu";
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F32_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::mkl::vs_silu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F64_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::mkl::vd_silu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
const F32_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::accelerate::vs_silu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
const F64_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::accelerate::vd_silu(xs, ys)
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Abs {
|
||||
const NAME: &'static str = "abs";
|
||||
const KERNEL: &'static str = "uabs";
|
||||
|
@ -508,6 +508,7 @@ impl Tensor {
|
||||
unary_op!(gelu_erf, GeluErf);
|
||||
unary_op!(erf, Erf);
|
||||
unary_op!(relu, Relu);
|
||||
unary_op!(silu, Silu);
|
||||
unary_op!(ceil, Ceil);
|
||||
unary_op!(floor, Floor);
|
||||
unary_op!(round, Round);
|
||||
|
@ -270,6 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||
);
|
||||
|
||||
// testing compared to pytorch nn.Silu()
|
||||
let y = x.silu()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.8577, 0.7311, 3.9281, 0.0806]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0881, 0.9277, 1.0527, 0.5747],
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||
|
@ -120,6 +120,13 @@ fn unary_op(device: &Device) -> Result<()> {
|
||||
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.silu()?, 4)?,
|
||||
[
|
||||
[-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
|
||||
[2.53, -0.2553, -0.1205, 1.5447, 2.6395]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
||||
|
Reference in New Issue
Block a user