From b60064780d09ab6733f5287b322ea5cb057d3136 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Wed, 14 Feb 2024 10:27:22 +0100
Subject: [PATCH] feat: add silu activation function (#1706)

* feat: add silu activation function

* use silu/arg in grad

* update candle-nn

* use node
---
 candle-core/src/accelerate.rs        | 32 ++++++++++++
 candle-core/src/backprop.rs          |  7 +++
 candle-core/src/metal_backend.rs     |  4 ++
 candle-core/src/mkl.rs               | 32 ++++++++++++
 candle-core/src/op.rs                | 73 ++++++++++++++++++++++++++++
 candle-core/src/tensor.rs            |  1 +
 candle-core/tests/grad_tests.rs      | 13 +++++
 candle-core/tests/tensor_tests.rs    |  7 +++
 candle-kernels/src/unary.cu          |  9 ++++
 candle-metal-kernels/src/lib.rs      |  2 +-
 candle-metal-kernels/src/tests.rs    | 19 ++++++++
 candle-metal-kernels/src/unary.metal |  5 ++
 candle-nn/src/activation.rs          |  2 +-
 candle-nn/src/ops.rs                 |  5 +-
 14 files changed, 206 insertions(+), 5 deletions(-)

diff --git a/candle-core/src/accelerate.rs b/candle-core/src/accelerate.rs
index 1cb34e19..d371d3b3 100644
--- a/candle-core/src/accelerate.rs
+++ b/candle-core/src/accelerate.rs
@@ -380,6 +380,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
     unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
 }
 
+#[inline]
+pub fn vs_exp_inplace(y: &mut [f32]) {
+    unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vd_exp_inplace(y: &mut [f64]) {
+    unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
 #[inline]
 pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
     for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@@ -402,6 +412,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
     }
 }
 
+#[inline]
+pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = -v
+    }
+    vs_exp_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = v / (1.0 + *y)
+    }
+}
+
+#[inline]
+pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = -v
+    }
+    vd_exp_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = v / (1.0 + *y)
+    }
+}
+
 macro_rules! binary_op {
     ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
         #[inline]
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index e7e3e129..26d73ea1 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -589,6 +589,13 @@ impl Tensor {
                     let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                     *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                 }
+                Op::Unary(arg, UnaryOp::Silu) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+                    let sigmoid_arg = (*node / arg)?;
+                    let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
+                    *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
+                }
                 Op::Elu(arg, alpha) => {
                     // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
                     let sum_grad = grads.or_insert(arg)?;
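The Silu backward rule above relies on silu(x) = x * sigmoid(x): dividing the forward result by its input (the `*node / arg` step) recovers sigmoid(x), which is then plugged into d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))). Below is a minimal standalone sketch, in plain Rust with no candle types and helper names chosen here for illustration, that checks this closed form against a central finite difference. Note that recovering sigmoid(x) as silu(x) / x is undefined at x = 0, which is why the sketch probes x = 0.7.

```rust
// Check the analytic silu gradient used in backprop.rs against a numeric one.
// `silu` and `silu_grad` are illustrative helpers, not part of the patch.
fn silu(x: f64) -> f64 {
    x / (1.0 + (-x).exp())
}

fn silu_grad(x: f64) -> f64 {
    // sigmoid(x) recovered from the forward value, mirroring `*node / arg`.
    let sigmoid = silu(x) / x;
    sigmoid * (1.0 + x * (1.0 - sigmoid))
}

fn main() {
    let (x, eps) = (0.7f64, 1e-6);
    let numeric = (silu(x + eps) - silu(x - eps)) / (2.0 * eps);
    println!("analytic = {:.6}, numeric = {:.6}", silu_grad(x), numeric);
}
```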
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index ebcad786..c19d7c56 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -679,6 +679,7 @@ impl BackendStorage for MetalStorage {
             ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
             ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
             ("uerf", DType::F32) => contiguous::erf::FLOAT,
+            ("usilu", DType::F32) => contiguous::silu::FLOAT,
             ("uabs", DType::F32) => contiguous::abs::FLOAT,
             ("uceil", DType::F32) => contiguous::ceil::FLOAT,
             ("ufloor", DType::F32) => contiguous::floor::FLOAT,
@@ -696,6 +697,7 @@ impl BackendStorage for MetalStorage {
             ("ugelu", DType::F16) => contiguous::gelu::HALF,
             ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
             ("uerf", DType::F16) => contiguous::erf::HALF,
+            ("usilu", DType::F16) => contiguous::silu::HALF,
             ("uabs", DType::F16) => contiguous::abs::HALF,
             ("uceil", DType::F16) => contiguous::ceil::HALF,
             ("ufloor", DType::F16) => contiguous::floor::HALF,
@@ -730,6 +732,7 @@ impl BackendStorage for MetalStorage {
             ("ugelu", DType::F32) => strided::gelu::FLOAT,
             ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
             ("uerf", DType::F32) => strided::erf::FLOAT,
+            ("usilu", DType::F32) => strided::silu::FLOAT,
             ("uabs", DType::F32) => strided::abs::FLOAT,
             ("uceil", DType::F32) => strided::ceil::FLOAT,
             ("ufloor", DType::F32) => strided::floor::FLOAT,
@@ -745,6 +748,7 @@ impl BackendStorage for MetalStorage {
             ("ugelu", DType::F16) => strided::gelu::HALF,
             ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
             ("uerf", DType::F16) => strided::erf::HALF,
+            ("usilu", DType::F16) => strided::silu::HALF,
             ("uabs", DType::F16) => strided::abs::HALF,
             ("uceil", DType::F16) => strided::ceil::HALF,
             ("ufloor", DType::F16) => strided::floor::HALF,
diff --git a/candle-core/src/mkl.rs b/candle-core/src/mkl.rs
index 26167e86..359add74 100644
--- a/candle-core/src/mkl.rs
+++ b/candle-core/src/mkl.rs
@@ -333,6 +333,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
     unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
 }
 
+#[inline]
+pub fn vs_exp_inplace(y: &mut [f32]) {
+    unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
+}
+
+#[inline]
+pub fn vd_exp_inplace(y: &mut [f64]) {
+    unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
+}
+
 #[inline]
 pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
     for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@@ -355,6 +365,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
     }
 }
 
+#[inline]
+pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = -v
+    }
+    vs_exp_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = v / (1.0 + *y)
+    }
+}
+
+#[inline]
+pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = -v
+    }
+    vd_exp_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = v / (1.0 + *y)
+    }
+}
+
 macro_rules! binary_op {
     ($fn_name:ident, $ty:ty, $mkl_name:ident) => {
         #[inline]
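Both the Accelerate and MKL paths above use the same two-pass scheme: stage -x in the output buffer, exponentiate it in place with the vendor routine (vvexpf/vvexp or vsExp/vdExp), then finish with x / (1 + e^-x). The staging pass exists because those vendor exponentials only operate in place over a single buffer, so the negated input has to live in ys before it can be exponentiated. A plain-Rust sketch of the scheme, with a scalar stand-in for the vectorized exponential and names invented here for illustration:

```rust
// Scalar stand-in for vvexpf / vsExp: exponentiate the buffer in place.
fn exp_inplace(ys: &mut [f32]) {
    for y in ys.iter_mut() {
        *y = y.exp();
    }
}

// Same three steps as vs_silu / vd_silu in the patch.
fn silu_slice(vs: &[f32], ys: &mut [f32]) {
    assert_eq!(vs.len(), ys.len());
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = -v; // stage the argument of exp in the output buffer
    }
    exp_inplace(ys); // ys[i] = exp(-vs[i])
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = v / (1.0 + *y); // silu(v) = v / (1 + exp(-v))
    }
}

fn main() {
    let vs = [-1.0f32, 0.0, 1.0, 2.0];
    let mut ys = [0.0f32; 4];
    silu_slice(&vs, &mut ys);
    println!("{ys:?}");
}
```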
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 868673e7..d920485c 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -61,6 +61,7 @@ pub enum UnaryOp {
     GeluErf,
     Erf,
     Relu,
+    Silu,
     Tanh,
     Floor,
     Ceil,
@@ -390,6 +391,7 @@ pub(crate) struct Gelu;
 pub(crate) struct GeluErf;
 pub(crate) struct Erf;
 pub(crate) struct Relu;
+pub(crate) struct Silu;
 pub(crate) struct Tanh;
 pub(crate) struct Floor;
 pub(crate) struct Ceil;
@@ -724,6 +726,77 @@ impl UnaryOpT for Erf {
     }
 }
 
+/// Silu operation
+impl UnaryOpT for Silu {
+    const NAME: &'static str = "silu";
+    const V: Self = Silu;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        v / (bf16::ONE + (-v).exp())
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        v / (f16::ONE + (-v).exp())
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        v / (1.0 + (-v).exp())
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        v / (1.0 + (-v).exp())
+    }
+    #[inline(always)]
+    fn u8(_: u8) -> u8 {
+        0
+    }
+    #[inline(always)]
+    fn u32(_: u32) -> u32 {
+        0
+    }
+    #[inline(always)]
+    fn i64(_: i64) -> i64 {
+        0
+    }
+    const KERNEL: &'static str = "usilu";
+
+    #[cfg(feature = "mkl")]
+    const F32_VEC: bool = true;
+
+    #[cfg(feature = "mkl")]
+    #[inline(always)]
+    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+        crate::mkl::vs_silu(xs, ys)
+    }
+
+    #[cfg(feature = "mkl")]
+    const F64_VEC: bool = true;
+
+    #[cfg(feature = "mkl")]
+    #[inline(always)]
+    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+        crate::mkl::vd_silu(xs, ys)
+    }
+
+    #[cfg(feature = "accelerate")]
+    const F32_VEC: bool = true;
+
+    #[cfg(feature = "accelerate")]
+    #[inline(always)]
+    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+        crate::accelerate::vs_silu(xs, ys)
+    }
+
+    #[cfg(feature = "accelerate")]
+    const F64_VEC: bool = true;
+
+    #[cfg(feature = "accelerate")]
+    #[inline(always)]
+    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+        crate::accelerate::vd_silu(xs, ys)
+    }
+}
+
 impl UnaryOpT for Abs {
     const NAME: &'static str = "abs";
     const KERNEL: &'static str = "uabs";
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 8596c957..a1aa9338 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -508,6 +508,7 @@ impl Tensor {
     unary_op!(gelu_erf, GeluErf);
     unary_op!(erf, Erf);
     unary_op!(relu, Relu);
+    unary_op!(silu, Silu);
     unary_op!(ceil, Ceil);
     unary_op!(floor, Floor);
     unary_op!(round, Round);
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 16e7a82f..76987635 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -270,6 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
         [0.7358, 2.0000, 0.2707, 1.0000]
     );
 
+    // testing compared to pytorch nn.Silu()
+    let y = x.silu()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [2.8577, 0.7311, 3.9281, 0.0806]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [1.0881, 0.9277, 1.0527, 0.5747],
+    );
+
     // manually checked: see comments
     let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
     let y = x.interpolate2d(6, 6)?.reshape(36)?;
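The expected vectors in the gradient test can be spot-checked by hand. The entry 0.7311 in the forward output is silu(1), and the matching gradient entry follows from the formula used in backprop.rs: sigmoid(1) * (1 + 1 * (1 - sigmoid(1))) ≈ 0.9277. A quick standalone check in plain Rust, not part of the test suite:

```rust
// Spot-check of the silu grad test expectations at x = 1.0:
// silu(1) ≈ 0.7311 and d/dx silu(1) ≈ 0.9277.
fn main() {
    let x = 1.0f64;
    let sigmoid = 1.0 / (1.0 + (-x).exp());
    let silu = x * sigmoid;
    let grad = sigmoid * (1.0 + x * (1.0 - sigmoid));
    println!("silu(1) = {silu:.4}, grad = {grad:.4}"); // 0.7311, 0.9277
}
```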
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 33bab1b6..40737e7b 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -120,6 +120,13 @@ fn unary_op(device: &Device) -> Result<()> {
             [0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
         ]
     );
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.silu()?, 4)?,
+        [
+            [-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
+            [2.53, -0.2553, -0.1205, 1.5447, 2.6395]
+        ]
+    );
     assert_eq!(
         test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
         [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu
index 409a337d..2256c6bb 100644
--- a/candle-kernels/src/unary.cu
+++ b/candle-kernels/src/unary.cu
@@ -55,6 +55,11 @@ __device__ __forceinline__ T relu_fwd(T x) {
     return maxg(x, zero);
 }
 
+template<typename T>
+__device__ __forceinline__ T silu_fwd(T x) {
+    return x / (static_cast<T>(1) + expg(-x));
+}
+
 #define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
 extern "C" __global__ void FN_NAME( \
     const size_t numel, \
@@ -103,6 +108,7 @@ UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
 UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
 UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
 UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
+UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
 UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
 #endif
 
@@ -127,6 +133,7 @@ UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
 UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
 UNARY_OP(__half, urelu_f16, relu_fwd(x))
 UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
+UNARY_OP(__half, usilu_f16, silu_fwd(x))
 UNARY_OP1(__half, upowf_f16, powg(x, param))
 #endif
 
@@ -173,5 +180,7 @@ UNARY_OP(float, urelu_f32, relu_fwd(x))
 UNARY_OP(double, urelu_f64, relu_fwd(x))
 UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
 UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
+UNARY_OP(float, usilu_f32, silu_fwd(x))
+UNARY_OP(double, usilu_f64, silu_fwd(x))
 UNARY_OP1(float, upowf_f32, powg(x, param))
 UNARY_OP1(double, upowf_f64, powg(x, param))
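With the CPU op, the backward rule, and the CUDA kernels above in place, the activation is available directly on Tensor. A hedged usage sketch follows; the tensor values and the CPU device are illustrative, and on a CUDA device (for example via Device::new_cuda(0)) the same call would be expected to route through the usilu_* kernels declared above:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Illustrative values: silu(0) = 0, silu(1) ~= 0.7311, silu(-1) ~= -0.2689.
    let xs = Tensor::new(&[-1.0f32, 0.0, 1.0, 2.0], &Device::Cpu)?;
    let ys = xs.silu()?;
    println!("{ys}");
    Ok(())
}
```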
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 2d27d230..33bc3453 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -183,7 +183,7 @@ macro_rules! ops{
 pub mod unary {
     ops!(
         cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
-        tanh, recip
+        tanh, recip, silu
     );
 }
 pub mod binary {
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 655161e5..459c8edb 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -231,6 +231,25 @@ fn gelu_f32() {
     assert_eq!(approx(results, 3), expected);
 }
 
+#[test]
+fn silu_f16() {
+    let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
+        .iter()
+        .map(|v| f16::from_f32(*v))
+        .collect();
+    let expected: Vec<f32> = vec![-0.0, -0.27, 0.0, 0.73, 1.76, 2.86, 10.0, 20.0];
+    let results = run(&v, unary::contiguous::silu::HALF);
+    assert_eq!(approx_f16(results, 2), expected);
+}
+
+#[test]
+fn silu_f32() {
+    let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
+    let expected: Vec<f32> = vec![-0.0, -0.269, 0.0, 0.731, 1.762, 2.858, 10.0, 20.0];
+    let results = run(&v, unary::contiguous::silu::FLOAT);
+    assert_eq!(approx(results, 3), expected);
+}
+
 #[test]
 fn binary_add_f32() {
     let left = vec![1.0f32, 2.0, 3.0];
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 7add58fd..1e0d5526 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -64,6 +64,9 @@ template <typename T> METAL_FUNC T relu(T in){
     }
     return in;
 }
+template <typename T> METAL_FUNC T silu(T in){
+    return in / (static_cast<T>(1) + exp(-in));
+}
 
 #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
 kernel void FN_NAME( \
@@ -108,6 +111,7 @@ UNARY_OP(neg)
 UNARY_OP(exp)
 UNARY_OP(log)
 UNARY_OP(gelu)
+UNARY_OP(silu)
 UNARY_OP(abs)
 UNARY_OP(ceil)
 UNARY_OP(floor)
@@ -135,6 +139,7 @@ BFLOAT_UNARY_OP(neg)
 BFLOAT_UNARY_OP(exp)
 BFLOAT_UNARY_OP(log)
 BFLOAT_UNARY_OP(gelu)
+BFLOAT_UNARY_OP(silu)
 BFLOAT_UNARY_OP(abs)
 BFLOAT_UNARY_OP(ceil)
 BFLOAT_UNARY_OP(floor)
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index e00463f0..60a7a6d1 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -30,7 +30,7 @@ impl super::Module for Activation {
             Self::Relu => xs.relu(),
             Self::Relu2 => xs.relu()?.sqr(),
             Self::Relu6 => xs.clamp(0f32, 6f32),
-            Self::Silu => crate::ops::silu(xs),
+            Self::Silu => xs.silu(),
             Self::Sigmoid => crate::ops::sigmoid(xs),
             Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
             Self::Swiglu => crate::ops::swiglu(xs),
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index abe33350..aaec8b56 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -35,13 +35,12 @@ pub fn log_softmax<D: Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
 }
 
 pub fn silu(xs: &Tensor) -> Result<Tensor> {
-    // TODO: Should we have a specialized op for this?
-    xs / (xs.neg()?.exp()? + 1.0)?
+    xs.silu()
 }
 
 pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
     let xs = xs.chunk(2, candle::D::Minus1)?;
-    crate::ops::silu(&xs[0])? * &xs[1]
+    &xs[0].silu()? * &xs[1]
 }
 
 pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
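On the candle-nn side, ops::silu and Activation::Silu now both route to the dedicated Tensor::silu op instead of composing neg/exp/add/div tensor ops, and swiglu keeps its meaning: split the last dimension into two halves, apply silu to the first and multiply by the second. A short module-level usage sketch, with values chosen here only for illustration:

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Activation, Module};

fn main() -> Result<()> {
    let xs = Tensor::new(&[[0.5f32, -0.5], [2.0, -2.0]], &Device::Cpu)?;
    // After this patch, Activation::Silu dispatches to xs.silu().
    let ys = Activation::Silu.forward(&xs)?;
    println!("{ys}");
    Ok(())
}
```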