mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
feat: add silu activation function (#1706)
* feat: add silu activation function * use silu/arg in grad * update candle-nn * use node
This commit is contained in:
@ -380,6 +380,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_exp_inplace(y: &mut [f32]) {
|
||||
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_exp_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
@ -402,6 +412,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vs_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vd_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! binary_op {
|
||||
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
|
||||
#[inline]
|
||||
|
@ -589,6 +589,13 @@ impl Tensor {
|
||||
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
||||
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Silu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
|
||||
let sigmoid_arg = (*node / arg)?;
|
||||
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
|
||||
}
|
||||
Op::Elu(arg, alpha) => {
|
||||
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
@ -679,6 +679,7 @@ impl BackendStorage for MetalStorage {
|
||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||
("usilu", DType::F32) => contiguous::silu::FLOAT,
|
||||
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||
@ -696,6 +697,7 @@ impl BackendStorage for MetalStorage {
|
||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||
("usilu", DType::F16) => contiguous::silu::HALF,
|
||||
("uabs", DType::F16) => contiguous::abs::HALF,
|
||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||
@ -730,6 +732,7 @@ impl BackendStorage for MetalStorage {
|
||||
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => strided::erf::FLOAT,
|
||||
("usilu", DType::F32) => strided::silu::FLOAT,
|
||||
("uabs", DType::F32) => strided::abs::FLOAT,
|
||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||
@ -745,6 +748,7 @@ impl BackendStorage for MetalStorage {
|
||||
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => strided::erf::HALF,
|
||||
("usilu", DType::F16) => strided::silu::HALF,
|
||||
("uabs", DType::F16) => strided::abs::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
|
@ -333,6 +333,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_exp_inplace(y: &mut [f32]) {
|
||||
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_exp_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
@ -355,6 +365,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vs_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vd_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! binary_op {
|
||||
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
|
||||
#[inline]
|
||||
|
@ -61,6 +61,7 @@ pub enum UnaryOp {
|
||||
GeluErf,
|
||||
Erf,
|
||||
Relu,
|
||||
Silu,
|
||||
Tanh,
|
||||
Floor,
|
||||
Ceil,
|
||||
@ -390,6 +391,7 @@ pub(crate) struct Gelu;
|
||||
pub(crate) struct GeluErf;
|
||||
pub(crate) struct Erf;
|
||||
pub(crate) struct Relu;
|
||||
pub(crate) struct Silu;
|
||||
pub(crate) struct Tanh;
|
||||
pub(crate) struct Floor;
|
||||
pub(crate) struct Ceil;
|
||||
@ -724,6 +726,77 @@ impl UnaryOpT for Erf {
|
||||
}
|
||||
}
|
||||
|
||||
/// Silu operation
|
||||
impl UnaryOpT for Silu {
|
||||
const NAME: &'static str = "silu";
|
||||
const V: Self = Silu;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v / (bf16::ONE + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v / (f16::ONE + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v / (1.0 + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v / (1.0 + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(_: i64) -> i64 {
|
||||
0
|
||||
}
|
||||
const KERNEL: &'static str = "usilu";
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F32_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::mkl::vs_silu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F64_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::mkl::vd_silu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
const F32_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::accelerate::vs_silu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
const F64_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::accelerate::vd_silu(xs, ys)
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Abs {
|
||||
const NAME: &'static str = "abs";
|
||||
const KERNEL: &'static str = "uabs";
|
||||
|
@ -508,6 +508,7 @@ impl Tensor {
|
||||
unary_op!(gelu_erf, GeluErf);
|
||||
unary_op!(erf, Erf);
|
||||
unary_op!(relu, Relu);
|
||||
unary_op!(silu, Silu);
|
||||
unary_op!(ceil, Ceil);
|
||||
unary_op!(floor, Floor);
|
||||
unary_op!(round, Round);
|
||||
|
@ -270,6 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||
);
|
||||
|
||||
// testing compared to pytorch nn.Silu()
|
||||
let y = x.silu()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.8577, 0.7311, 3.9281, 0.0806]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0881, 0.9277, 1.0527, 0.5747],
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||
|
@ -120,6 +120,13 @@ fn unary_op(device: &Device) -> Result<()> {
|
||||
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.silu()?, 4)?,
|
||||
[
|
||||
[-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
|
||||
[2.53, -0.2553, -0.1205, 1.5447, 2.6395]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
||||
|
@ -55,6 +55,11 @@ __device__ __forceinline__ T relu_fwd(T x) {
|
||||
return maxg(x, zero);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T silu_fwd(T x) {
|
||||
return x / (static_cast<scalar_t>(1) + expg(-x));
|
||||
}
|
||||
|
||||
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
@ -103,6 +108,7 @@ UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
|
||||
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
|
||||
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
|
||||
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
|
||||
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
|
||||
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
|
||||
#endif
|
||||
|
||||
@ -127,6 +133,7 @@ UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
|
||||
UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
|
||||
UNARY_OP(__half, urelu_f16, relu_fwd(x))
|
||||
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
|
||||
UNARY_OP(__half, usilu_f16, silu_fwd(x))
|
||||
UNARY_OP1(__half, upowf_f16, powg(x, param))
|
||||
#endif
|
||||
|
||||
@ -173,5 +180,7 @@ UNARY_OP(float, urelu_f32, relu_fwd(x))
|
||||
UNARY_OP(double, urelu_f64, relu_fwd(x))
|
||||
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
|
||||
UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
|
||||
UNARY_OP(float, usilu_f32, silu_fwd(x))
|
||||
UNARY_OP(double, usilu_f64, silu_fwd(x))
|
||||
UNARY_OP1(float, upowf_f32, powg(x, param))
|
||||
UNARY_OP1(double, upowf_f64, powg(x, param))
|
||||
|
@ -183,7 +183,7 @@ macro_rules! ops{
|
||||
pub mod unary {
|
||||
ops!(
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||
tanh, recip
|
||||
tanh, recip, silu
|
||||
);
|
||||
}
|
||||
pub mod binary {
|
||||
|
@ -231,6 +231,25 @@ fn gelu_f32() {
|
||||
assert_eq!(approx(results, 3), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn silu_f16() {
|
||||
let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
let expected: Vec<f32> = vec![-0.0, -0.27, 0.0, 0.73, 1.76, 2.86, 10.0, 20.0];
|
||||
let results = run(&v, unary::contiguous::silu::HALF);
|
||||
assert_eq!(approx_f16(results, 2), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn silu_f32() {
|
||||
let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
|
||||
let expected: Vec<f32> = vec![-0.0, -0.269, 0.0, 0.731, 1.762, 2.858, 10.0, 20.0];
|
||||
let results = run(&v, unary::contiguous::silu::FLOAT);
|
||||
assert_eq!(approx(results, 3), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_add_f32() {
|
||||
let left = vec![1.0f32, 2.0, 3.0];
|
||||
|
@ -64,6 +64,9 @@ template <typename T> METAL_FUNC T relu(T in){
|
||||
}
|
||||
return in;
|
||||
}
|
||||
template <typename T> METAL_FUNC T silu(T in){
|
||||
return in / (static_cast<T>(1) + exp(-in));
|
||||
}
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
@ -108,6 +111,7 @@ UNARY_OP(neg)
|
||||
UNARY_OP(exp)
|
||||
UNARY_OP(log)
|
||||
UNARY_OP(gelu)
|
||||
UNARY_OP(silu)
|
||||
UNARY_OP(abs)
|
||||
UNARY_OP(ceil)
|
||||
UNARY_OP(floor)
|
||||
@ -135,6 +139,7 @@ BFLOAT_UNARY_OP(neg)
|
||||
BFLOAT_UNARY_OP(exp)
|
||||
BFLOAT_UNARY_OP(log)
|
||||
BFLOAT_UNARY_OP(gelu)
|
||||
BFLOAT_UNARY_OP(silu)
|
||||
BFLOAT_UNARY_OP(abs)
|
||||
BFLOAT_UNARY_OP(ceil)
|
||||
BFLOAT_UNARY_OP(floor)
|
||||
|
@ -30,7 +30,7 @@ impl super::Module for Activation {
|
||||
Self::Relu => xs.relu(),
|
||||
Self::Relu2 => xs.relu()?.sqr(),
|
||||
Self::Relu6 => xs.clamp(0f32, 6f32),
|
||||
Self::Silu => crate::ops::silu(xs),
|
||||
Self::Silu => xs.silu(),
|
||||
Self::Sigmoid => crate::ops::sigmoid(xs),
|
||||
Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
|
||||
Self::Swiglu => crate::ops::swiglu(xs),
|
||||
|
@ -35,13 +35,12 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
|
||||
}
|
||||
|
||||
pub fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: Should we have a specialized op for this?
|
||||
xs / (xs.neg()?.exp()? + 1.0)?
|
||||
xs.silu()
|
||||
}
|
||||
|
||||
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.chunk(2, candle::D::Minus1)?;
|
||||
crate::ops::silu(&xs[0])? * &xs[1]
|
||||
&xs[0].silu()? * &xs[1]
|
||||
}
|
||||
|
||||
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||
|
Reference in New Issue
Block a user