From 6e25822d4fcd3321f1e078706683b990780ba1ae Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Wed, 6 Dec 2023 09:59:44 -0500 Subject: [PATCH] Fix gelu for large x --- candle-metal-kernels/src/tests.rs | 23 +++++++++++++++++++++-- candle-metal-kernels/src/unary.metal | 11 ++++++++--- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 59f54fa9..37b07167 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -205,6 +205,25 @@ fn cos_strided_random() { ); } +#[test] +fn gelu_f16() { + let v: Vec = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let expected: Vec = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::HALF); + assert_eq!(approx_f16(results, 2), expected); +} + +#[test] +fn gelu_f32() { + let v: Vec = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]; + let expected: Vec = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::FLOAT); + assert_eq!(approx(results, 3), expected); +} + #[test] fn binary_add_f32() { let left = vec![1.0f32, 2.0, 3.0]; @@ -527,8 +546,8 @@ fn cos_f16() { .collect(); let results = run(&v, unary::contiguous::cos::HALF); let expected: Vec = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); - assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]); - assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); + assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]); + assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]); } fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 88139af9..529162bd 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -42,9 +42,14 @@ template METAL_FUNC T erf(T in){ return T(sign*y); } -template METAL_FUNC T id(T in){ return in; } -template METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); } -template METAL_FUNC T gelu(T x){ +template METAL_FUNC T id(T in) { return in; } +template METAL_FUNC T gelu_erf(T x) { + return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); +} +template METAL_FUNC T gelu(T x) { + if (x > 5) { + return x; + } T x_sq = x * x; T x_cube = x_sq * x; T alpha = x + static_cast(0.044715) * x_cube;