mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Preliminary support for mkl based gelu. (#187)
* Preliminary support for mkl based gelu. * Add the vectorized function for unary ops. * Get the mkl specialized gelu to work.
This commit is contained in:
@ -60,6 +60,17 @@ pub(crate) trait UnaryOp {
|
||||
fn f64(v1: f64) -> f64;
|
||||
fn u8(v1: u8) -> u8;
|
||||
fn u32(v1: u32) -> u32;
|
||||
|
||||
// There is no very good way to represent optional function in traits so we go for an explicit
|
||||
// boolean flag to mark the function as existing.
|
||||
const BF16_VEC: bool = false;
|
||||
fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
|
||||
const F16_VEC: bool = false;
|
||||
fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
|
||||
const F32_VEC: bool = false;
|
||||
fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
|
||||
const F64_VEC: bool = false;
|
||||
fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
|
||||
}
|
||||
|
||||
pub(crate) trait BinaryOp {
|
||||
@ -219,6 +230,24 @@ impl UnaryOp for Gelu {
|
||||
0
|
||||
}
|
||||
const KERNEL: &'static str = "ugelu";
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F32_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::mkl::vs_gelu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F64_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::mkl::vd_gelu(xs, ys)
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOp for Relu {
|
||||
|
Reference in New Issue
Block a user