mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Preliminary support for mkl based gelu. (#187)
* Preliminary support for mkl based gelu. * Add the vectorized function for unary ops. * Get the mkl specialized gelu to work.
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
#![allow(dead_code)]
|
||||
use libc::{c_char, c_double, c_float, c_int};
|
||||
|
||||
mod ffi {
|
||||
@ -156,9 +157,8 @@ pub unsafe fn hgemm(
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[inline]
|
||||
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
@ -167,9 +167,8 @@ pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
unsafe { ffi::vsTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[inline]
|
||||
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
@ -177,3 +176,36 @@ pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
}
|
||||
unsafe { ffi::vdTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
// The vector functions from mkl can be performed in place by using the same array for input and
|
||||
// output.
|
||||
// https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-2/vector-mathematical-functions.html
|
||||
#[inline]
|
||||
pub fn vs_tanh_inplace(y: &mut [f32]) {
|
||||
unsafe { ffi::vsTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_tanh_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
|
||||
}
|
||||
vs_tanh_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = 0.5 * v * (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
|
||||
}
|
||||
vd_tanh_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = 0.5 * v * (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user