Preliminary support for mkl based gelu. (#187)

* Preliminary support for mkl based gelu.

* Add the vectorized function for unary ops.

* Get the mkl specialized gelu to work.
This commit is contained in:
Laurent Mazare
2023-07-18 07:48:48 +01:00
committed by GitHub
parent b8abe2bb4b
commit d73df74cb2
3 changed files with 135 additions and 12 deletions

View File

@ -148,6 +148,48 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut
}
}
fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
vs: &[T],
layout: &Layout,
mut f: F,
mut f_vec: FV,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let mut ys: Vec<U> = Vec::with_capacity(len);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(len) };
ys
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let mut result = vec![];
result.reserve(layout.shape().elem_count());
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
} else {
// TODO: Use f_vec here.
for index in block_start_index {
for offset in 0..block_len {
let v = unsafe { vs.get_unchecked(index + offset) };
result.push(f(*v))
}
}
}
result
}
}
}
// This function maps over two strided index sequences.
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
lhs_l: &Layout,
@ -864,20 +906,40 @@ impl BackendStorage for CpuStorage {
fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
match self {
Self::BF16(storage) => {
let data = unary_map(storage, layout, B::bf16);
Ok(Self::BF16(data))
if B::BF16_VEC {
let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
Ok(Self::BF16(data))
} else {
let data = unary_map(storage, layout, B::bf16);
Ok(Self::BF16(data))
}
}
Self::F16(storage) => {
let data = unary_map(storage, layout, B::f16);
Ok(Self::F16(data))
if B::F16_VEC {
let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
Ok(Self::F16(data))
} else {
let data = unary_map(storage, layout, B::f16);
Ok(Self::F16(data))
}
}
Self::F32(storage) => {
let data = unary_map(storage, layout, B::f32);
Ok(Self::F32(data))
if B::F32_VEC {
let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
Ok(Self::F32(data))
} else {
let data = unary_map(storage, layout, B::f32);
Ok(Self::F32(data))
}
}
Self::F64(storage) => {
let data = unary_map(storage, layout, B::f64);
Ok(Self::F64(data))
if B::F64_VEC {
let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
Ok(Self::F64(data))
} else {
let data = unary_map(storage, layout, B::f64);
Ok(Self::F64(data))
}
}
Self::U8(storage) => {
let data = unary_map(storage, layout, B::u8);