From 2ca086939f91f5d8ccec745e47648f74fa520988 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 30 Nov 2023 11:40:39 +0100 Subject: [PATCH 1/2] Put back affine strided tests --- candle-metal-kernels/src/tests.rs | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 66dc8d01..59f54fa9 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -295,7 +295,7 @@ fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { output.read_to_vec::(v.len()) } -fn _run_affine_strided( +fn run_affine_strided( v: &[T], shape: &[usize], strides: &[usize], @@ -314,7 +314,7 @@ fn _run_affine_strided( &device, command_buffer, &kernels, - "affine_float", + "affine_float_strided", shape, &input, strides, @@ -327,7 +327,8 @@ fn _run_affine_strided( command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::(v.len()) + let len: usize = shape.iter().product(); + output.read_to_vec::(len) } #[test] @@ -345,15 +346,17 @@ fn affine() { assert_eq!(result, vec![2.6; 40_000]); } -// #[test] -// fn affine_strided() { -// let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; -// let mul = 1.5; -// let add = 1.1; -// let result = run_affine_(&input, mul, add); -// assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); - -// } +#[test] +fn affine_strided() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let shape = [4]; + let strides = [2]; + let result = run_affine_strided(&input, &shape, &strides, mul, add); + // 1 on 2 + assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]); +} #[test] fn index_select() { From 6e25822d4fcd3321f1e078706683b990780ba1ae Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Wed, 6 Dec 2023 09:59:44 -0500 Subject: [PATCH 2/2] Fix gelu for large x --- candle-metal-kernels/src/tests.rs | 23 +++++++++++++++++++++-- candle-metal-kernels/src/unary.metal | 11 ++++++++--- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 59f54fa9..37b07167 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -205,6 +205,25 @@ fn cos_strided_random() { ); } +#[test] +fn gelu_f16() { + let v: Vec = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let expected: Vec = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::HALF); + assert_eq!(approx_f16(results, 2), expected); +} + +#[test] +fn gelu_f32() { + let v: Vec = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]; + let expected: Vec = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::FLOAT); + assert_eq!(approx(results, 3), expected); +} + #[test] fn binary_add_f32() { let left = vec![1.0f32, 2.0, 3.0]; @@ -527,8 +546,8 @@ fn cos_f16() { .collect(); let results = run(&v, unary::contiguous::cos::HALF); let expected: Vec = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); - assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]); - assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); + assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]); + assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]); } fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 88139af9..529162bd 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -42,9 +42,14 @@ template METAL_FUNC T erf(T in){ return T(sign*y); } -template METAL_FUNC T id(T in){ return in; } -template METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); } -template METAL_FUNC T gelu(T x){ +template METAL_FUNC T id(T in) { return in; } +template METAL_FUNC T gelu_erf(T x) { + return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); +} +template METAL_FUNC T gelu(T x) { + if (x > 5) { + return x; + } T x_sq = x * x; T x_cube = x_sq * x; T alpha = x + static_cast(0.044715) * x_cube;