mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Preliminary support for mkl based gelu. (#187)
* Preliminary support for mkl based gelu. * Add the vectorized function for unary ops. * Get the mkl specialized gelu to work.
This commit is contained in:
@ -148,6 +148,48 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut
|
||||
}
|
||||
}
|
||||
|
||||
fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
||||
vs: &[T],
|
||||
layout: &Layout,
|
||||
mut f: F,
|
||||
mut f_vec: FV,
|
||||
) -> Vec<U> {
|
||||
match layout.strided_blocks() {
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
let mut ys: Vec<U> = Vec::with_capacity(len);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
||||
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
|
||||
// SAFETY: values are all set by f_vec.
|
||||
unsafe { ys.set_len(len) };
|
||||
ys
|
||||
}
|
||||
crate::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len,
|
||||
} => {
|
||||
let mut result = vec![];
|
||||
result.reserve(layout.shape().elem_count());
|
||||
// Specialize the case where block_len is one to avoid the second loop.
|
||||
if block_len == 1 {
|
||||
for index in block_start_index {
|
||||
let v = unsafe { vs.get_unchecked(index) };
|
||||
result.push(f(*v))
|
||||
}
|
||||
} else {
|
||||
// TODO: Use f_vec here.
|
||||
for index in block_start_index {
|
||||
for offset in 0..block_len {
|
||||
let v = unsafe { vs.get_unchecked(index + offset) };
|
||||
result.push(f(*v))
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This function maps over two strided index sequences.
|
||||
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
||||
lhs_l: &Layout,
|
||||
@ -864,20 +906,40 @@ impl BackendStorage for CpuStorage {
|
||||
fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
match self {
|
||||
Self::BF16(storage) => {
|
||||
let data = unary_map(storage, layout, B::bf16);
|
||||
Ok(Self::BF16(data))
|
||||
if B::BF16_VEC {
|
||||
let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
|
||||
Ok(Self::BF16(data))
|
||||
} else {
|
||||
let data = unary_map(storage, layout, B::bf16);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
}
|
||||
Self::F16(storage) => {
|
||||
let data = unary_map(storage, layout, B::f16);
|
||||
Ok(Self::F16(data))
|
||||
if B::F16_VEC {
|
||||
let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
|
||||
Ok(Self::F16(data))
|
||||
} else {
|
||||
let data = unary_map(storage, layout, B::f16);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
}
|
||||
Self::F32(storage) => {
|
||||
let data = unary_map(storage, layout, B::f32);
|
||||
Ok(Self::F32(data))
|
||||
if B::F32_VEC {
|
||||
let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
|
||||
Ok(Self::F32(data))
|
||||
} else {
|
||||
let data = unary_map(storage, layout, B::f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
}
|
||||
Self::F64(storage) => {
|
||||
let data = unary_map(storage, layout, B::f64);
|
||||
Ok(Self::F64(data))
|
||||
if B::F64_VEC {
|
||||
let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
|
||||
Ok(Self::F64(data))
|
||||
} else {
|
||||
let data = unary_map(storage, layout, B::f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
}
|
||||
Self::U8(storage) => {
|
||||
let data = unary_map(storage, layout, B::u8);
|
||||
|
Reference in New Issue
Block a user