Use mkl to accelerate binary ops. (#190)

* Vectorized binary ops with mkl.

* Improve the binary op mkl support.

* Push the support for mkl binary ops.

* Proper vectorization of binary ops.

* Proper mkl'isation when broadcasting binary ops.
This commit is contained in:
Laurent Mazare
2023-07-18 12:04:39 +01:00
committed by GitHub
parent b706f32839
commit ff61a42ad7
4 changed files with 218 additions and 13 deletions

View File

@ -7,6 +7,15 @@ mod ffi {
pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsAdd(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdAdd(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsSub(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdSub(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsMul(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn sgemm_(
transa: *const c_char,
transb: *const c_char,
@ -190,6 +199,7 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
@ -200,6 +210,7 @@ pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
}
}
#[inline]
pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
@ -209,3 +220,29 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
*y = 0.5 * v * (1.0 + *y)
}
}
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
#[inline]
pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) {
let a_len = a.len();
let b_len = b.len();
let y_len = y.len();
if a_len != y_len || b_len != y_len {
panic!(
"{} a,b,y len mismatch {a_len} {b_len} {y_len}",
stringify!($fn_name)
);
}
unsafe { ffi::$mkl_name(a_len as i32, a.as_ptr(), b.as_ptr(), y.as_mut_ptr()) }
}
};
}
binary_op!(vs_add, f32, vsAdd);
binary_op!(vd_add, f64, vdAdd);
binary_op!(vs_sub, f32, vsSub);
binary_op!(vd_sub, f64, vdSub);
binary_op!(vs_mul, f32, vsMul);
binary_op!(vd_mul, f64, vdMul);
binary_op!(vs_div, f32, vsDiv);
binary_op!(vd_div, f64, vdDiv);