mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
feat: add silu activation function (#1706)
* feat: add silu activation function * use silu/arg in grad * update candle-nn * use node
This commit is contained in:
@ -183,7 +183,7 @@ macro_rules! ops{
|
||||
pub mod unary {
|
||||
ops!(
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||
tanh, recip
|
||||
tanh, recip, silu
|
||||
);
|
||||
}
|
||||
pub mod binary {
|
||||
|
@ -231,6 +231,25 @@ fn gelu_f32() {
|
||||
assert_eq!(approx(results, 3), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn silu_f16() {
|
||||
let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
let expected: Vec<f32> = vec![-0.0, -0.27, 0.0, 0.73, 1.76, 2.86, 10.0, 20.0];
|
||||
let results = run(&v, unary::contiguous::silu::HALF);
|
||||
assert_eq!(approx_f16(results, 2), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn silu_f32() {
|
||||
let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
|
||||
let expected: Vec<f32> = vec![-0.0, -0.269, 0.0, 0.731, 1.762, 2.858, 10.0, 20.0];
|
||||
let results = run(&v, unary::contiguous::silu::FLOAT);
|
||||
assert_eq!(approx(results, 3), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_add_f32() {
|
||||
let left = vec![1.0f32, 2.0, 3.0];
|
||||
|
@ -64,6 +64,9 @@ template <typename T> METAL_FUNC T relu(T in){
|
||||
}
|
||||
return in;
|
||||
}
|
||||
template <typename T> METAL_FUNC T silu(T in){
|
||||
return in / (static_cast<T>(1) + exp(-in));
|
||||
}
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
@ -108,6 +111,7 @@ UNARY_OP(neg)
|
||||
UNARY_OP(exp)
|
||||
UNARY_OP(log)
|
||||
UNARY_OP(gelu)
|
||||
UNARY_OP(silu)
|
||||
UNARY_OP(abs)
|
||||
UNARY_OP(ceil)
|
||||
UNARY_OP(floor)
|
||||
@ -135,6 +139,7 @@ BFLOAT_UNARY_OP(neg)
|
||||
BFLOAT_UNARY_OP(exp)
|
||||
BFLOAT_UNARY_OP(log)
|
||||
BFLOAT_UNARY_OP(gelu)
|
||||
BFLOAT_UNARY_OP(silu)
|
||||
BFLOAT_UNARY_OP(abs)
|
||||
BFLOAT_UNARY_OP(ceil)
|
||||
BFLOAT_UNARY_OP(floor)
|
||||
|
Reference in New Issue
Block a user