mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Preliminary support for mkl based gelu. (#187)
* Preliminary support for mkl based gelu. * Add the vectorized function for unary ops. * Get the mkl specialized gelu to work.
This commit is contained in:
@ -148,6 +148,48 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut
|
||||
}
|
||||
}
|
||||
|
||||
fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
||||
vs: &[T],
|
||||
layout: &Layout,
|
||||
mut f: F,
|
||||
mut f_vec: FV,
|
||||
) -> Vec<U> {
|
||||
match layout.strided_blocks() {
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
let mut ys: Vec<U> = Vec::with_capacity(len);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
||||
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
|
||||
// SAFETY: values are all set by f_vec.
|
||||
unsafe { ys.set_len(len) };
|
||||
ys
|
||||
}
|
||||
crate::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len,
|
||||
} => {
|
||||
let mut result = vec![];
|
||||
result.reserve(layout.shape().elem_count());
|
||||
// Specialize the case where block_len is one to avoid the second loop.
|
||||
if block_len == 1 {
|
||||
for index in block_start_index {
|
||||
let v = unsafe { vs.get_unchecked(index) };
|
||||
result.push(f(*v))
|
||||
}
|
||||
} else {
|
||||
// TODO: Use f_vec here.
|
||||
for index in block_start_index {
|
||||
for offset in 0..block_len {
|
||||
let v = unsafe { vs.get_unchecked(index + offset) };
|
||||
result.push(f(*v))
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This function maps over two strided index sequences.
|
||||
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
||||
lhs_l: &Layout,
|
||||
@ -864,20 +906,40 @@ impl BackendStorage for CpuStorage {
|
||||
fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
match self {
|
||||
Self::BF16(storage) => {
|
||||
let data = unary_map(storage, layout, B::bf16);
|
||||
Ok(Self::BF16(data))
|
||||
if B::BF16_VEC {
|
||||
let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
|
||||
Ok(Self::BF16(data))
|
||||
} else {
|
||||
let data = unary_map(storage, layout, B::bf16);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
}
|
||||
Self::F16(storage) => {
|
||||
let data = unary_map(storage, layout, B::f16);
|
||||
Ok(Self::F16(data))
|
||||
if B::F16_VEC {
|
||||
let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
|
||||
Ok(Self::F16(data))
|
||||
} else {
|
||||
let data = unary_map(storage, layout, B::f16);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
}
|
||||
Self::F32(storage) => {
|
||||
let data = unary_map(storage, layout, B::f32);
|
||||
Ok(Self::F32(data))
|
||||
if B::F32_VEC {
|
||||
let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
|
||||
Ok(Self::F32(data))
|
||||
} else {
|
||||
let data = unary_map(storage, layout, B::f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
}
|
||||
Self::F64(storage) => {
|
||||
let data = unary_map(storage, layout, B::f64);
|
||||
Ok(Self::F64(data))
|
||||
if B::F64_VEC {
|
||||
let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
|
||||
Ok(Self::F64(data))
|
||||
} else {
|
||||
let data = unary_map(storage, layout, B::f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
}
|
||||
Self::U8(storage) => {
|
||||
let data = unary_map(storage, layout, B::u8);
|
||||
|
@ -1,3 +1,4 @@
|
||||
#![allow(dead_code)]
|
||||
use libc::{c_char, c_double, c_float, c_int};
|
||||
|
||||
mod ffi {
|
||||
@ -156,9 +157,8 @@ pub unsafe fn hgemm(
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[inline]
|
||||
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
@ -167,9 +167,8 @@ pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
unsafe { ffi::vsTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[inline]
|
||||
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
@ -177,3 +176,36 @@ pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
}
|
||||
unsafe { ffi::vdTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
// The vector functions from mkl can be performed in place by using the same array for input and
|
||||
// output.
|
||||
// https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-2/vector-mathematical-functions.html
|
||||
/// Applies `tanh` element-wise to `y` in place via the MKL vector function
/// `vsTanh` (f32 variant).
///
/// MKL vector-math functions support in-place operation by passing the same
/// array as both input and output, per the comment and oneMKL link above.
// NOTE(review): `y.len() as i32` silently truncates for slices longer than
// i32::MAX elements — presumably never hit in practice, confirm callers.
#[inline]
pub fn vs_tanh_inplace(y: &mut [f32]) {
    // SAFETY: input and output alias the same `y.len()`-element buffer, which
    // MKL explicitly allows for vector-math functions.
    unsafe { ffi::vsTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
|
||||
|
||||
/// Applies `tanh` element-wise to `y` in place via the MKL vector function
/// `vdTanh` (f64 variant).
///
/// MKL vector-math functions support in-place operation by passing the same
/// array as both input and output, per the comment and oneMKL link above.
// NOTE(review): `y.len() as i32` silently truncates for slices longer than
// i32::MAX elements — presumably never hit in practice, confirm callers.
#[inline]
pub fn vd_tanh_inplace(y: &mut [f64]) {
    // SAFETY: input and output alias the same `y.len()`-element buffer, which
    // MKL explicitly allows for vector-math functions.
    unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
|
||||
|
||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
|
||||
}
|
||||
vs_tanh_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = 0.5 * v * (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
|
||||
}
|
||||
vd_tanh_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = 0.5 * v * (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
@ -60,6 +60,17 @@ pub(crate) trait UnaryOp {
|
||||
fn f64(v1: f64) -> f64;
|
||||
fn u8(v1: u8) -> u8;
|
||||
fn u32(v1: u32) -> u32;
|
||||
|
||||
// There is no very good way to represent optional function in traits so we go for an explicit
|
||||
// boolean flag to mark the function as existing.
|
||||
const BF16_VEC: bool = false;
|
||||
fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
|
||||
const F16_VEC: bool = false;
|
||||
fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
|
||||
const F32_VEC: bool = false;
|
||||
fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
|
||||
const F64_VEC: bool = false;
|
||||
fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
|
||||
}
|
||||
|
||||
pub(crate) trait BinaryOp {
|
||||
@ -219,6 +230,24 @@ impl UnaryOp for Gelu {
|
||||
0
|
||||
}
|
||||
const KERNEL: &'static str = "ugelu";
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F32_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::mkl::vs_gelu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F64_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::mkl::vd_gelu(xs, ys)
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOp for Relu {
|
||||
|
Reference in New Issue
Block a user