mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Fix gelu for large x
This commit is contained in:
@ -205,6 +205,25 @@ fn cos_strided_random() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn gelu_f16() {
|
||||||
|
let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
|
||||||
|
.iter()
|
||||||
|
.map(|v| f16::from_f32(*v))
|
||||||
|
.collect();
|
||||||
|
let expected: Vec<f32> = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0];
|
||||||
|
let results = run(&v, unary::contiguous::gelu::HALF);
|
||||||
|
assert_eq!(approx_f16(results, 2), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn gelu_f32() {
|
||||||
|
let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
|
||||||
|
let expected: Vec<f32> = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0];
|
||||||
|
let results = run(&v, unary::contiguous::gelu::FLOAT);
|
||||||
|
assert_eq!(approx(results, 3), expected);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn binary_add_f32() {
|
fn binary_add_f32() {
|
||||||
let left = vec![1.0f32, 2.0, 3.0];
|
let left = vec![1.0f32, 2.0, 3.0];
|
||||||
@ -527,8 +546,8 @@ fn cos_f16() {
|
|||||||
.collect();
|
.collect();
|
||||||
let results = run(&v, unary::contiguous::cos::HALF);
|
let results = run(&v, unary::contiguous::cos::HALF);
|
||||||
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
|
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
|
||||||
assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]);
|
assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]);
|
||||||
assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
|
assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
|
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
|
||||||
|
@ -42,9 +42,14 @@ template <typename T> METAL_FUNC T erf(T in){
|
|||||||
|
|
||||||
return T(sign*y);
|
return T(sign*y);
|
||||||
}
|
}
|
||||||
template <typename T> METAL_FUNC T id(T in){ return in; }
|
template <typename T> METAL_FUNC T id(T in) { return in; }
|
||||||
template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
|
template <typename T> METAL_FUNC T gelu_erf(T x) {
|
||||||
template <typename T> METAL_FUNC T gelu(T x){
|
return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2);
|
||||||
|
}
|
||||||
|
template <typename T> METAL_FUNC T gelu(T x) {
|
||||||
|
if (x > 5) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
T x_sq = x * x;
|
T x_sq = x * x;
|
||||||
T x_cube = x_sq * x;
|
T x_cube = x_sq * x;
|
||||||
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
||||||
|
Reference in New Issue
Block a user