Mklize more unary ops. (#191)

* Mklize more unary ops.

* Even more unary ops.
This commit is contained in:
Laurent Mazare
2023-07-18 14:32:49 +02:00
committed by GitHub
parent ff61a42ad7
commit 3307db204a
2 changed files with 183 additions and 6 deletions

View File

@ -6,6 +6,16 @@ mod ffi {
extern "C" { extern "C" {
pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float); pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double); pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsExp(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdExp(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsLn(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdLn(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsSin(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdSin(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsCos(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdCos(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsSqrt(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdSqrt(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsAdd(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float); pub fn vsAdd(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdAdd(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double); pub fn vdAdd(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
@ -166,6 +176,126 @@ pub unsafe fn hgemm(
) )
} }
#[inline]
pub fn vs_exp(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsExp(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_exp(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdExp(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsLn(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_ln(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdLn(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_sin(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsSin(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_sin(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdSin(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_cos(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsCos(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_cos(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdCos(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_sqrt(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsSqrt(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_sqrt(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdSqrt(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_sqr(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsMul(a_len as i32, a.as_ptr(), a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdMul(a_len as i32, a.as_ptr(), a.as_ptr(), y.as_mut_ptr()) }
}
#[inline] #[inline]
fn vs_tanh(a: &[f32], y: &mut [f32]) { fn vs_tanh(a: &[f32], y: &mut [f32]) {
let a_len = a.len(); let a_len = a.len();

View File

@ -199,16 +199,63 @@ macro_rules! unary_op {
} }
} }
}; };
($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
impl UnaryOp for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
#[inline(always)]
fn bf16($a: bf16) -> bf16 {
$e
}
#[inline(always)]
fn f16($a: f16) -> f16 {
$e
}
#[inline(always)]
fn f32($a: f32) -> f32 {
$e
}
#[inline(always)]
fn f64($a: f64) -> f64 {
$e
}
#[inline(always)]
fn u8(_: u8) -> u8 {
todo!("no unary function for u8")
}
#[inline(always)]
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::$f32_vec(xs, ys)
}
#[cfg(feature = "mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::$f64_vec(xs, ys)
}
}
};
} }
unary_op!(Exp, "exp", v, v.exp()); unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln()); unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin()); unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos()); unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Abs, "abs", v, v.abs()); unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v); unary_op!(Neg, "neg", v, -v);
unary_op!(Sqr, "sqr", v, v * v); unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt()); unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
/// `gelu` operation /// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions> /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>