Creating Gelu op (no backward).

This commit is contained in:
Nicolas Patry
2023-06-22 21:56:46 +02:00
parent 4ffdeb4e23
commit fd21c708ab
3 changed files with 43 additions and 0 deletions

View File

@ -22,6 +22,7 @@ pub(crate) enum Op {
Sqrt(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Gelu(Tensor),
// TODO: Support for custom ops.
}
@ -52,6 +53,7 @@ pub(crate) struct Sub;
pub(crate) struct Neg;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
impl BinaryOp for Add {
const NAME: &'static str = "add";
@ -136,3 +138,29 @@ impl UnaryOp for Sqrt {
const KERNEL_F32: &'static str = "usqrt_f32";
const KERNEL_F64: &'static str = "usqrt_f64";
}
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
#[inline]
pub fn gelu_f32(v: f32) -> f32 {
0.5 * (v)
* (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
#[inline]
pub fn gelu_f64(v: f64) -> f64 {
0.5 * (v)
* (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
impl UnaryOp for Gelu {
const NAME: &'static str = "gelu";
fn f32(v1: f32) -> f32 {
gelu_f32(v1)
}
fn f64(v1: f64) -> f64 {
gelu_f64(v1)
}
const KERNEL_F32: &'static str = "gelu_f32";
const KERNEL_F64: &'static str = "gelu_f64";
}