Preliminary support for mkl based gelu. (#187)

* Preliminary support for mkl based gelu.

* Add the vectorized function for unary ops.

* Get the mkl specialized gelu to work.
This commit is contained in:
Laurent Mazare
2023-07-18 07:48:48 +01:00
committed by GitHub
parent b8abe2bb4b
commit d73df74cb2
3 changed files with 135 additions and 12 deletions

View File

@ -60,6 +60,17 @@ pub(crate) trait UnaryOp {
fn f64(v1: f64) -> f64;
fn u8(v1: u8) -> u8;
fn u32(v1: u32) -> u32;
// There is no very good way to represent optional function in traits so we go for an explicit
// boolean flag to mark the function as existing.
const BF16_VEC: bool = false;
fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
const F16_VEC: bool = false;
fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
const F32_VEC: bool = false;
fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
const F64_VEC: bool = false;
fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
pub(crate) trait BinaryOp {
@ -219,6 +230,24 @@ impl UnaryOp for Gelu {
0
}
const KERNEL: &'static str = "ugelu";
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::vs_gelu(xs, ys)
}
#[cfg(feature = "mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_gelu(xs, ys)
}
}
impl UnaryOp for Relu {