Optimize the gelu f16 opt. (#2008)

* Optimize the gelu f16 opt.

* And add a test.
Laurent Mazare
2024-04-04 16:28:23 +02:00
committed by GitHub
parent f48c07e242
commit 30b145150f
2 changed files with 19 additions and 8 deletions
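For context, the diff below leaves the math untouched: the kernels keep the standard tanh-based approximation of GELU and only hardcode its sqrt(2/pi) factor instead of recomputing it per call,

\operatorname{gelu}(x) \approx 0.5\,x\,\bigl(1 + \tanh\bigl(\sqrt{2/\pi}\; x\,(1 + 0.044715\,x^2)\bigr)\bigr), \qquad \sqrt{2/\pi} \approx 0.79788456\ldots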


@@ -457,6 +457,13 @@ unary_op!(Recip, "recip", v, v.recip());
 unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
 unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
+// Hardcode the value for sqrt(2/pi)
+// https://github.com/huggingface/candle/issues/1982
+#[allow(clippy::excessive_precision)]
+const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
+#[allow(clippy::excessive_precision)]
+const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
 /// Tanh based approximation of the `gelu` operation
 /// GeluErf is the more precise one.
 /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
@@ -469,7 +476,7 @@ impl UnaryOpT for Gelu {
             * v
             * (bf16::ONE
                 + bf16::tanh(
-                    (bf16::from_f32_const(2.0) / bf16::PI).sqrt()
+                    bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
                         * v
                         * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
                 ))
@@ -480,22 +487,18 @@
             * v
             * (f16::ONE
                 + f16::tanh(
-                    (f16::from_f32_const(2.0) / f16::PI).sqrt()
+                    f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
                         * v
                         * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
                 ))
     }
     #[inline(always)]
     fn f32(v: f32) -> f32 {
-        0.5 * v
-            * (1.0
-                + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
+        0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
     }
     #[inline(always)]
     fn f64(v: f64) -> f64 {
-        0.5 * v
-            * (1.0
-                + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
+        0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
     }
     #[inline(always)]
     fn u8(_: u8) -> u8 {
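As a sanity check on the hardcoded literal, here is a small standalone sketch (not part of the commit) that prints the constant next to the expression the old kernels evaluated, along with the half-precision values the f16/bf16 paths now convert from; it assumes the `half` crate as a dependency, as candle-core itself uses it.

// Standalone sketch, not part of the diff: compare the hardcoded literal with
// the expression the old kernels spelled out inline, and show the converted
// half-precision values.
use half::{bf16, f16}; // assumed dependency

#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;

fn main() {
    // The old code wrote this expression out inline in every kernel.
    println!("computed f64 = {:.17}", (2.0f64 / std::f64::consts::PI).sqrt());
    println!("constant f64 = {:.17}", SQRT_TWO_OVER_PI_F64);
    println!("computed f32 = {:.9}", (2.0f32 / std::f32::consts::PI).sqrt());
    println!("constant f32 = {:.9}", SQRT_TWO_OVER_PI_F32);
    // The f16/bf16 kernels now convert the f32 constant instead of doing the
    // division and square root in half precision.
    println!("f16 constant  = {}", f16::from_f32(SQRT_TWO_OVER_PI_F32));
    println!("bf16 constant = {}", bf16::from_f32(SQRT_TWO_OVER_PI_F32));
}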


@@ -106,6 +106,14 @@ fn unary_op(device: &Device) -> Result<()> {
             [2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
         ]
     );
+    let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
+    assert_eq!(
+        test_utils::to_vec2_round(&t_f16, 2)?,
+        [
+            [-0.0, 0.84, 4.0, -0.05, 0.35],
+            [2.69, -0.07, -0.11, 1.73, 2.79]
+        ],
+    );
     assert_eq!(
         test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
         [
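Finally, a hedged usage sketch along the lines of the new test: it exercises the f16 gelu path through the public API and compares it against the f32 result. It assumes `candle-core` as a dependency, and the input values are illustrative rather than the ones used in the test above.

// Standalone sketch, not part of the diff: run gelu in f16 and compare it
// against the f32 result. Input values are illustrative.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[-3f32, 1., 4., -0.1, 0.5]], &Device::Cpu)?;
    // Cast to f16, apply gelu, then cast back to f32 for easy inspection.
    let gelu_f16 = t.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
    let gelu_f32 = t.gelu()?;
    println!("f16 path: {:?}", gelu_f16.to_vec2::<f32>()?);
    println!("f32 path: {:?}", gelu_f32.to_vec2::<f32>()?);
    Ok(())
}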