Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Use mkl to accelerate binary ops. (#190)
* Vectorized binary ops with mkl.
* Improve the binary op mkl support.
* Push the support for mkl binary ops.
* Proper vectorization of binary ops.
* Proper mkl'isation when broadcasting binary ops.
@@ -276,6 +276,119 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     }
 }
 
+fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
+    lhs_l: &Layout,
+    rhs_l: &Layout,
+    lhs: &[T],
+    rhs: &[T],
+    mut f: F,
+    mut f_vec: FV,
+) -> Vec<T> {
+    let el_count = lhs_l.shape().elem_count();
+    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
+        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
+            let mut ys: Vec<T> = Vec::with_capacity(el_count);
+            let ys_to_set = ys.spare_capacity_mut();
+            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
+            // SAFETY: values are all set by f_vec.
+            unsafe { ys.set_len(el_count) };
+            ys
+        }
+        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
+            Some(ob) if ob.right_broadcast == 1 => {
+                let mut ys: Vec<T> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+                let mut dst_i = 0;
+                for src_i in (o_l1..o_l2).step_by(ob.len) {
+                    f_vec(
+                        &lhs[src_i..src_i + ob.len],
+                        rhs,
+                        &mut ys_to_set[dst_i..dst_i + ob.len],
+                    );
+                    dst_i += ob.len;
+                }
+                // SAFETY: values are all set by f_vec.
+                unsafe { ys.set_len(el_count) };
+                ys
+            }
+            Some(ob) => {
+                let mut i_in_block = 0;
+                let mut i_right_broadcast = 0;
+                lhs[o_l1..o_l2]
+                    .iter()
+                    .map(|&l| {
+                        let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
+                        i_right_broadcast += 1;
+                        if i_right_broadcast >= ob.right_broadcast {
+                            i_in_block += 1;
+                            i_right_broadcast = 0;
+                        }
+                        if i_in_block >= ob.len {
+                            i_in_block = 0
+                        }
+                        f(l, *r)
+                    })
+                    .collect()
+            }
+            None => lhs_l
+                .strided_index()
+                .zip(rhs_l.strided_index())
+                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+                .collect(),
+        },
+        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
+            Some(ob) if ob.right_broadcast == 1 => {
+                let mut ys: Vec<T> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+                let mut dst_i = 0;
+                for src_i in (o_r1..o_r2).step_by(ob.len) {
+                    f_vec(
+                        lhs,
+                        &rhs[src_i..src_i + ob.len],
+                        &mut ys_to_set[dst_i..dst_i + ob.len],
+                    );
+                    dst_i += ob.len;
+                }
+                // SAFETY: values are all set by f_vec.
+                unsafe { ys.set_len(el_count) };
+                ys
+            }
+            Some(ob) => {
+                let mut i_in_block = 0;
+                let mut i_right_broadcast = 0;
+                rhs[o_r1..o_r2]
+                    .iter()
+                    .map(|&r| {
+                        let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
+                        i_right_broadcast += 1;
+                        if i_right_broadcast >= ob.right_broadcast {
+                            i_in_block += 1;
+                            i_right_broadcast = 0;
+                        }
+                        if i_in_block >= ob.len {
+                            i_in_block = 0
+                        }
+                        f(*l, r)
+                    })
+                    .collect()
+            }
+            None => lhs_l
+                .strided_index()
+                .zip(rhs_l.strided_index())
+                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+                .collect(),
+        },
+        _ => lhs_l
+            .strided_index()
+            .zip(rhs_l.strided_index())
+            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+            .collect(),
+    }
+}
+
 struct Affine(f64, f64);
 
 impl Map1 for Affine {
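The contiguous fast path above writes the result into the output Vec's uninitialized spare capacity with a single vectorized call and only then sets the length. A minimal standalone sketch of that pattern (illustrative only: binary_contiguous and the scalar closure are stand-ins for candle's binary_map_vec and an MKL kernel such as vsAdd):

use std::slice;

// Illustrative helper, not a candle API: same shape as the contiguous branch
// of binary_map_vec above.
fn binary_contiguous<F>(lhs: &[f32], rhs: &[f32], f_vec: F) -> Vec<f32>
where
    F: Fn(&[f32], &[f32], &mut [f32]),
{
    assert_eq!(lhs.len(), rhs.len());
    let n = lhs.len();
    let mut ys: Vec<f32> = Vec::with_capacity(n);
    // View the uninitialized spare capacity as an &mut [f32] destination.
    let dst = ys.spare_capacity_mut();
    let dst = unsafe { slice::from_raw_parts_mut(dst.as_mut_ptr() as *mut f32, n) };
    f_vec(lhs, rhs, dst);
    // SAFETY: f_vec wrote all n elements before the length is set.
    unsafe { ys.set_len(n) };
    ys
}

fn main() {
    let a = [1.0f32, 2.0, 3.0];
    let b = [10.0f32, 20.0, 30.0];
    // Scalar stand-in for a vectorized kernel.
    let ys = binary_contiguous(&a, &b, |x, y, out| {
        for ((o, &l), &r) in out.iter_mut().zip(x).zip(y) {
            *o = l + r;
        }
    });
    assert_eq!(ys, vec![11.0, 22.0, 33.0]);
}

The broadcast cases in the diff follow the same idea, invoking the kernel once per contiguous block instead of once over the whole buffer.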
@@ -961,27 +1074,51 @@ impl BackendStorage for CpuStorage {
     fn binary_impl<B: BinaryOp>(&self, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
         match (self, rhs) {
             (Self::BF16(lhs), Self::BF16(rhs)) => {
-                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16);
+                let data = if B::BF16_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::bf16, B::bf16_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16)
+                };
                 Ok(Self::BF16(data))
             }
             (Self::F16(lhs), Self::F16(rhs)) => {
-                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f16);
+                let data = if B::F16_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f16, B::f16_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f16)
+                };
                 Ok(Self::F16(data))
             }
             (Self::F32(lhs), Self::F32(rhs)) => {
-                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f32);
+                let data = if B::F32_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f32, B::f32_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f32)
+                };
                 Ok(Self::F32(data))
             }
             (Self::F64(lhs), Self::F64(rhs)) => {
-                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f64);
+                let data = if B::F64_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f64, B::f64_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f64)
+                };
                 Ok(Self::F64(data))
             }
             (Self::U32(lhs), Self::U32(rhs)) => {
-                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32);
+                let data = if B::U32_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u32, B::u32_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u32)
+                };
                 Ok(Self::U32(data))
             }
             (Self::U8(lhs), Self::U8(rhs)) => {
-                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u8);
+                let data = if B::U8_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u8)
+                };
                 Ok(Self::U8(data))
             }
             _ => {
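Each dtype arm above consults the op's associated *_VEC flag and only calls binary_map_vec when a vectorized kernel is actually available, otherwise it falls back to the scalar binary_map. A standalone sketch of that dispatch pattern (BinOp, Add and apply are illustrative names, not candle's):

// The op advertises a vectorized kernel through an associated const; the
// backend picks between the scalar map and the vectorized map accordingly.
trait BinOp {
    const F32_VEC: bool = false;
    fn f32(v1: f32, v2: f32) -> f32;
    // Default is a no-op; only consulted when F32_VEC is true.
    fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {}
}

struct Add;
impl BinOp for Add {
    const F32_VEC: bool = true;
    fn f32(v1: f32, v2: f32) -> f32 {
        v1 + v2
    }
    fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
        // Stand-in for a call into a vectorized backend such as MKL's vs_add.
        for ((y, &a), &b) in ys.iter_mut().zip(xs1).zip(xs2) {
            *y = a + b;
        }
    }
}

fn apply<B: BinOp>(xs1: &[f32], xs2: &[f32]) -> Vec<f32> {
    let mut ys = vec![0f32; xs1.len()];
    if B::F32_VEC {
        B::f32_vec(xs1, xs2, &mut ys);
    } else {
        for ((y, &a), &b) in ys.iter_mut().zip(xs1).zip(xs2) {
            *y = B::f32(a, b);
        }
    }
    ys
}

fn main() {
    let ys = apply::<Add>(&[1.0f32, 2.0], &[3.0f32, 4.0]);
    assert_eq!(ys, vec![4.0, 6.0]);
}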
@@ -7,6 +7,15 @@ mod ffi {
         pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float);
         pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double);
 
+        pub fn vsAdd(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
+        pub fn vdAdd(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
+        pub fn vsSub(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
+        pub fn vdSub(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
+        pub fn vsMul(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
+        pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
+        pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
+        pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
+
         pub fn sgemm_(
             transa: *const c_char,
             transb: *const c_char,
@@ -190,6 +199,7 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
     unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
 }
 
+#[inline]
 pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
     for (&v, y) in vs.iter().zip(ys.iter_mut()) {
         *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
@@ -200,6 +210,7 @@ pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
     }
 }
 
+#[inline]
 pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
     for (&v, y) in vs.iter().zip(ys.iter_mut()) {
         *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
@@ -209,3 +220,29 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
         *y = 0.5 * v * (1.0 + *y)
     }
 }
+
+macro_rules! binary_op {
+    ($fn_name:ident, $ty:ty, $mkl_name:ident) => {
+        #[inline]
+        pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) {
+            let a_len = a.len();
+            let b_len = b.len();
+            let y_len = y.len();
+            if a_len != y_len || b_len != y_len {
+                panic!(
+                    "{} a,b,y len mismatch {a_len} {b_len} {y_len}",
+                    stringify!($fn_name)
+                );
+            }
+            unsafe { ffi::$mkl_name(a_len as i32, a.as_ptr(), b.as_ptr(), y.as_mut_ptr()) }
+        }
+    };
+}
+binary_op!(vs_add, f32, vsAdd);
+binary_op!(vd_add, f64, vdAdd);
+binary_op!(vs_sub, f32, vsSub);
+binary_op!(vd_sub, f64, vdSub);
+binary_op!(vs_mul, f32, vsMul);
+binary_op!(vd_mul, f64, vdMul);
+binary_op!(vs_div, f32, vsDiv);
+binary_op!(vd_div, f64, vdDiv);
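The binary_op! macro above generates thin, length-checked wrappers around the MKL VML entry points (vsAdd, vdAdd, ...). A runnable sketch of the same wrapper pattern, with a plain-Rust kernel standing in for MKL so that nothing needs to be linked (vs_add_kernel and this local binary_op! are illustrative, not the crate's):

// Stand-in for an MKL VML signature: y[i] = a[i] + b[i] over n elements.
fn vs_add_kernel(n: i32, a: *const f32, b: *const f32, y: *mut f32) {
    for i in 0..n as usize {
        unsafe { *y.add(i) = *a.add(i) + *b.add(i) };
    }
}

// Same shape as the wrapper generated in the diff: check the lengths, then
// hand raw pointers and the element count to the kernel.
macro_rules! binary_op {
    ($fn_name:ident, $ty:ty, $kernel:ident) => {
        #[inline]
        pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) {
            assert!(a.len() == y.len() && b.len() == y.len(), "len mismatch");
            $kernel(y.len() as i32, a.as_ptr(), b.as_ptr(), y.as_mut_ptr())
        }
    };
}
binary_op!(vs_add, f32, vs_add_kernel);

fn main() {
    let a = [1.0f32, 2.0, 3.0, 4.0];
    let b = [0.5f32; 4];
    let mut y = [0.0f32; 4];
    vs_add(&a, &b, &mut y);
    assert_eq!(y, [1.5, 2.5, 3.5, 4.5]);
}

With the real mkl feature enabled, the generated vs_add/vd_add/... wrappers forward to ffi::vsAdd and friends in exactly this way.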
@@ -83,6 +83,19 @@ pub(crate) trait BinaryOp {
     fn f64(v1: f64, v2: f64) -> f64;
     fn u8(v1: u8, v2: u8) -> u8;
     fn u32(v1: u32, v2: u32) -> u32;
+
+    const BF16_VEC: bool = false;
+    fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
+    const F16_VEC: bool = false;
+    fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {}
+    const F32_VEC: bool = false;
+    fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {}
+    const F64_VEC: bool = false;
+    fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {}
+    const U8_VEC: bool = false;
+    fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
+    const U32_VEC: bool = false;
+    fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
 }
 
 pub(crate) struct Add;
@@ -101,7 +114,7 @@ pub(crate) struct Gelu;
 pub(crate) struct Relu;
 
 macro_rules! bin_op {
-    ($op:ident, $name: literal, $e: expr) => {
+    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
         impl BinaryOp for $op {
             const NAME: &'static str = $name;
             const KERNEL: &'static str = concat!("b", $name);
@@ -130,14 +143,29 @@ macro_rules! bin_op {
             fn u32(v1: u32, v2: u32) -> u32 {
                 $e(v1, v2)
             }
+
+            #[cfg(feature = "mkl")]
+            const F32_VEC: bool = true;
+            #[cfg(feature = "mkl")]
+            const F64_VEC: bool = true;
+            #[cfg(feature = "mkl")]
+            #[inline(always)]
+            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
+                crate::mkl::$f32_vec(xs1, xs2, ys)
+            }
+            #[cfg(feature = "mkl")]
+            #[inline(always)]
+            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
+                crate::mkl::$f64_vec(xs1, xs2, ys)
+            }
         }
     };
 }
 
-bin_op!(Add, "add", |v1, v2| v1 + v2);
-bin_op!(Sub, "sub", |v1, v2| v1 - v2);
-bin_op!(Mul, "mul", |v1, v2| v1 * v2);
-bin_op!(Div, "div", |v1, v2| v1 / v2);
+bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
+bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
+bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
+bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
 
 macro_rules! unary_op {
     ($op: ident, $name: literal, $a: ident, $e: expr) => {
@@ -115,9 +115,12 @@ fn main() -> Result<()> {
         let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
         let token_type_ids = token_ids.zeros_like()?;
         println!("Loaded and encoded {:?}", start.elapsed());
-        for _ in 0..args.n {
+        for idx in 0..args.n {
             let start = std::time::Instant::now();
-            let _ys = model.forward(&token_ids, &token_type_ids)?;
+            let ys = model.forward(&token_ids, &token_type_ids)?;
+            if idx == 0 {
+                println!("{ys}");
+            }
             println!("Took {:?}", start.elapsed());
         }
     } else {