Preliminary support for mkl based gelu. (#187)

* Preliminary support for mkl based gelu.

* Add the vectorized function for unary ops.

* Get the mkl specialized gelu to work.
Author: Laurent Mazare
Date: 2023-07-18 07:48:48 +01:00
Committed by: GitHub
Parent: b8abe2bb4b
Commit: d73df74cb2
3 changed files with 135 additions and 12 deletions
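For reference, the MKL path added below evaluates the usual tanh approximation of GELU: gelu(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · x · (1 + 0.044715 · x²))). A minimal scalar sketch of that formula (plain Rust, not part of the patch):

```rust
// Scalar reference for the tanh-based GELU approximation that the new
// vs_gelu / vd_gelu helpers vectorize; plain Rust, no MKL required.
fn gelu_scalar(x: f32) -> f32 {
    let inner = (2.0f32 / std::f32::consts::PI).sqrt() * x * (1.0 + 0.044715 * x * x);
    0.5 * x * (1.0 + inner.tanh())
}

fn main() {
    for &x in &[-2.0f32, -0.5, 0.0, 0.5, 2.0] {
        println!("gelu({x}) = {}", gelu_scalar(x));
    }
}
```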


@@ -148,6 +148,48 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut
     }
 }
+fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
+    vs: &[T],
+    layout: &Layout,
+    mut f: F,
+    mut f_vec: FV,
+) -> Vec<U> {
+    match layout.strided_blocks() {
+        crate::StridedBlocks::SingleBlock { start_offset, len } => {
+            let mut ys: Vec<U> = Vec::with_capacity(len);
+            let ys_to_set = ys.spare_capacity_mut();
+            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
+            // SAFETY: values are all set by f_vec.
+            unsafe { ys.set_len(len) };
+            ys
+        }
+        crate::StridedBlocks::MultipleBlocks {
+            block_start_index,
+            block_len,
+        } => {
+            let mut result = vec![];
+            result.reserve(layout.shape().elem_count());
+            // Specialize the case where block_len is one to avoid the second loop.
+            if block_len == 1 {
+                for index in block_start_index {
+                    let v = unsafe { vs.get_unchecked(index) };
+                    result.push(f(*v))
+                }
+            } else {
+                // TODO: Use f_vec here.
+                for index in block_start_index {
+                    for offset in 0..block_len {
+                        let v = unsafe { vs.get_unchecked(index + offset) };
+                        result.push(f(*v))
+                    }
+                }
+            }
+            result
+        }
+    }
+}
 // This function maps over two strided index sequences.
 fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     lhs_l: &Layout,
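The SingleBlock branch above avoids zero-initializing the output buffer: it reserves capacity, exposes the spare capacity to f_vec as a regular slice, and only calls set_len once everything has been written. A small standalone sketch of that pattern; the map_bulk name and the f32-only signature are illustrative, not from the crate:

```rust
// Minimal sketch of the uninitialized-output pattern used by the
// SingleBlock fast path: reserve, let a bulk function fill the buffer,
// then set_len once every slot has been written.
fn map_bulk<F: FnOnce(&[f32], &mut [f32])>(xs: &[f32], f_vec: F) -> Vec<f32> {
    let len = xs.len();
    let mut ys: Vec<f32> = Vec::with_capacity(len);
    let spare = &mut ys.spare_capacity_mut()[..len];
    // Same trick as the patch: expose the spare capacity as an initialized
    // slice so the bulk function can write into it directly.
    let out = unsafe { std::mem::transmute::<_, &mut [f32]>(spare) };
    f_vec(xs, out);
    // SAFETY: relies on f_vec having written all `len` elements.
    unsafe { ys.set_len(len) };
    ys
}

fn main() {
    let doubled = map_bulk(&[1.0, 2.0, 3.0], |xs, ys| {
        for (x, y) in xs.iter().zip(ys.iter_mut()) {
            *y = x * 2.0;
        }
    });
    assert_eq!(doubled, vec![2.0, 4.0, 6.0]);
}
```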
@@ -864,20 +906,40 @@ impl BackendStorage for CpuStorage {
     fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
         match self {
             Self::BF16(storage) => {
-                let data = unary_map(storage, layout, B::bf16);
-                Ok(Self::BF16(data))
+                if B::BF16_VEC {
+                    let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
+                    Ok(Self::BF16(data))
+                } else {
+                    let data = unary_map(storage, layout, B::bf16);
+                    Ok(Self::BF16(data))
+                }
             }
             Self::F16(storage) => {
-                let data = unary_map(storage, layout, B::f16);
-                Ok(Self::F16(data))
+                if B::F16_VEC {
+                    let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
+                    Ok(Self::F16(data))
+                } else {
+                    let data = unary_map(storage, layout, B::f16);
+                    Ok(Self::F16(data))
+                }
             }
             Self::F32(storage) => {
-                let data = unary_map(storage, layout, B::f32);
-                Ok(Self::F32(data))
+                if B::F32_VEC {
+                    let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
+                    Ok(Self::F32(data))
+                } else {
+                    let data = unary_map(storage, layout, B::f32);
+                    Ok(Self::F32(data))
+                }
             }
             Self::F64(storage) => {
-                let data = unary_map(storage, layout, B::f64);
-                Ok(Self::F64(data))
+                if B::F64_VEC {
+                    let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
+                    Ok(Self::F64(data))
+                } else {
+                    let data = unary_map(storage, layout, B::f64);
+                    Ok(Self::F64(data))
+                }
             }
             Self::U8(storage) => {
                 let data = unary_map(storage, layout, B::u8);


@@ -1,3 +1,4 @@
+#![allow(dead_code)]
 use libc::{c_char, c_double, c_float, c_int};
 mod ffi {
@@ -156,9 +157,8 @@ pub unsafe fn hgemm(
     )
 }
-#[allow(dead_code)]
 #[inline]
-pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
+fn vs_tanh(a: &[f32], y: &mut [f32]) {
     let a_len = a.len();
     let y_len = y.len();
     if a_len != y_len {
@@ -167,9 +167,8 @@ pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
     unsafe { ffi::vsTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
 }
-#[allow(dead_code)]
 #[inline]
-pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
+fn vd_tanh(a: &[f64], y: &mut [f64]) {
     let a_len = a.len();
     let y_len = y.len();
     if a_len != y_len {
@@ -177,3 +176,36 @@ pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
     }
     unsafe { ffi::vdTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
 }
+// The vector functions from mkl can be performed in place by using the same array for input and
+// output.
+// https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-2/vector-mathematical-functions.html
+#[inline]
+pub fn vs_tanh_inplace(y: &mut [f32]) {
+    unsafe { ffi::vsTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
+}
+#[inline]
+pub fn vd_tanh_inplace(y: &mut [f64]) {
+    unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
+}
+pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+    }
+    vs_tanh_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = 0.5 * v * (1.0 + *y)
+    }
+}
+pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+    }
+    vd_tanh_inplace(ys);
+    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+        *y = 0.5 * v * (1.0 + *y)
+    }
+}
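vs_gelu and vd_gelu are split into two element-wise passes around the in-place tanh so that every step can go through an MKL vector call. A plain-Rust mirror of the same structure, handy as a reference when MKL is not linked (the gelu_two_pass helper is illustrative, not part of the patch):

```rust
// Same two-pass structure as vs_gelu, with the scalar tanh standing in
// for the in-place vsTanh call.
fn gelu_two_pass(vs: &[f32], ys: &mut [f32]) {
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v);
    }
    for y in ys.iter_mut() {
        *y = y.tanh();
    }
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = 0.5 * v * (1.0 + *y);
    }
}

fn main() {
    let xs = [-1.0f32, 0.0, 1.0];
    let mut ys = [0.0f32; 3];
    gelu_two_pass(&xs, &mut ys);
    println!("{ys:?}"); // roughly [-0.159, 0.0, 0.841]
}
```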


@@ -60,6 +60,17 @@ pub(crate) trait UnaryOp {
     fn f64(v1: f64) -> f64;
     fn u8(v1: u8) -> u8;
     fn u32(v1: u32) -> u32;
+    // There is no very good way to represent optional function in traits so we go for an explicit
+    // boolean flag to mark the function as existing.
+    const BF16_VEC: bool = false;
+    fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
+    const F16_VEC: bool = false;
+    fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
+    const F32_VEC: bool = false;
+    fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
+    const F64_VEC: bool = false;
+    fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
 }
 pub(crate) trait BinaryOp {
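The comment above describes the design choice: Rust traits have no direct way to mark a method as optionally provided, so each vectorized function is paired with an associated const that defaults to false plus a no-op default body. A self-contained sketch of the pattern with hypothetical names (FastOp, Double, apply):

```rust
// Ops opt in to a vectorized path by overriding both the flag and the
// default no-op body; callers branch on the flag.
trait FastOp {
    fn scalar(x: f32) -> f32;
    const F32_VEC: bool = false;
    fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
}

struct Double;

impl FastOp for Double {
    fn scalar(x: f32) -> f32 {
        x * 2.0
    }
    const F32_VEC: bool = true;
    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
        for (x, y) in xs.iter().zip(ys.iter_mut()) {
            *y = x * 2.0;
        }
    }
}

fn apply<O: FastOp>(xs: &[f32]) -> Vec<f32> {
    if O::F32_VEC {
        let mut ys = vec![0.0f32; xs.len()];
        O::f32_vec(xs, &mut ys);
        ys
    } else {
        xs.iter().map(|&x| O::scalar(x)).collect()
    }
}

fn main() {
    assert_eq!(apply::<Double>(&[1.0, 2.0]), vec![2.0, 4.0]);
}
```

Since the flag is a const, the compiler can drop the untaken branch entirely, so ops that stay on the scalar path pay nothing for the check.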
@@ -219,6 +230,24 @@ impl UnaryOp for Gelu {
         0
     }
     const KERNEL: &'static str = "ugelu";
+    #[cfg(feature = "mkl")]
+    const F32_VEC: bool = true;
+    #[cfg(feature = "mkl")]
+    #[inline(always)]
+    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+        crate::mkl::vs_gelu(xs, ys)
+    }
+    #[cfg(feature = "mkl")]
+    const F64_VEC: bool = true;
+    #[cfg(feature = "mkl")]
+    #[inline(always)]
+    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+        crate::mkl::vd_gelu(xs, ys)
+    }
 }
 impl UnaryOp for Relu {
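Putting the pieces together: when the crate is built with the mkl feature, Gelu::F32_VEC is true, unary_impl selects unary_map_vec, and a contiguous f32 CPU tensor is handed to crate::mkl::vs_gelu in a single call. A hypothetical end-to-end sketch; the candle_core crate path and the Tensor::gelu helper are assumptions about the public API rather than part of this diff:

```rust
// Hypothetical usage sketch; assumes `--features mkl` (with Intel MKL
// linked) so that the vectorized GELU path is taken for f32 CPU tensors.
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::new(&[-1.0f32, 0.0, 1.0, 2.0], &Device::Cpu)?;
    // Dispatches to unary_impl::<Gelu>, which takes the unary_map_vec /
    // crate::mkl::vs_gelu fast path when the mkl feature is enabled.
    let g = t.gelu()?;
    println!("{g}");
    Ok(())
}
```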