Preliminary support for mkl based gelu. (#187)

* Preliminary support for mkl based gelu.

* Add the vectorized function for unary ops.

* Get the mkl specialized gelu to work.
This commit is contained in:
Laurent Mazare
2023-07-18 07:48:48 +01:00
committed by GitHub
parent b8abe2bb4b
commit d73df74cb2
3 changed files with 135 additions and 12 deletions

View File

@ -1,3 +1,4 @@
#![allow(dead_code)]
use libc::{c_char, c_double, c_float, c_int};
mod ffi {
@ -156,9 +157,8 @@ pub unsafe fn hgemm(
)
}
#[allow(dead_code)]
#[inline]
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
fn vs_tanh(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@ -167,9 +167,8 @@ pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
unsafe { ffi::vsTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[allow(dead_code)]
#[inline]
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
fn vd_tanh(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@ -177,3 +176,36 @@ pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
}
unsafe { ffi::vdTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
// The vector functions from mkl can be performed in place by using the same array for input and
// output.
// https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-2/vector-mathematical-functions.html
#[inline]
pub fn vs_tanh_inplace(y: &mut [f32]) {
unsafe { ffi::vsTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
}
vs_tanh_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = 0.5 * v * (1.0 + *y)
}
}
pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
}
vd_tanh_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = 0.5 * v * (1.0 + *y)
}
}