Use mkl to accelerate binary ops. (#190)

* Vectorized binary ops with mkl.

* Improve the binary op mkl support.

* Push the support for mkl binary ops.

* Proper vectorization of binary ops.

* Proper mkl'isation when broadcasting binary ops.
This commit is contained in:
Laurent Mazare
2023-07-18 12:04:39 +01:00
committed by GitHub
parent b706f32839
commit ff61a42ad7
4 changed files with 218 additions and 13 deletions

View File

@ -276,6 +276,119 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
} }
} }
/// Applies a binary element-wise operation over two tensors' storage, using the
/// vectorized kernel `f_vec` on contiguous runs and falling back to the scalar
/// `f` when the memory layout forces element-by-element access.
///
/// * `lhs_l` / `rhs_l` — layouts describing how `lhs` / `rhs` are strided.
/// * `lhs` / `rhs` — the raw backing slices.
/// * `f` — scalar fallback, applied per element pair.
/// * `f_vec` — vectorized kernel: fills its `&mut [T]` output from two equal-length
///   input slices (e.g. an MKL call such as `vsAdd`).
///
/// Returns a freshly allocated `Vec<T>` of `lhs_l.shape().elem_count()` elements.
fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
    mut f_vec: FV,
) -> Vec<T> {
    let el_count = lhs_l.shape().elem_count();
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        // Both sides fully contiguous: one vectorized call over the whole range.
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
            let mut ys: Vec<T> = Vec::with_capacity(el_count);
            // `spare_capacity_mut` yields `&mut [MaybeUninit<T>]`; the transmute
            // reinterprets it as `&mut [T]` so `f_vec` can write into it directly.
            // NOTE(review): creating a `&mut [T]` over uninitialized memory is
            // technically UB even for `T: Copy`; writing through
            // `MaybeUninit::write` would be the strictly sound form — confirm
            // whether this matters for the FFI kernels used here.
            let ys_to_set = ys.spare_capacity_mut();
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(el_count) };
            ys
        }
        // Lhs contiguous, rhs not: try to exploit a broadcast structure on rhs.
        // `offsets_b` presumably describes a repeated block (`start`, `len`,
        // `right_broadcast`) — TODO confirm against the Layout implementation.
        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
            // No inner repetition: rhs is a block of `ob.len` elements that is
            // tiled across lhs, so we can vectorize one block at a time.
            Some(ob) if ob.right_broadcast == 1 => {
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_l1..o_l2).step_by(ob.len) {
                    // NOTE(review): `rhs` is passed whole rather than sliced to
                    // `&rhs[ob.start..ob.start + ob.len]`; this only works if the
                    // caller guarantees rhs.len() == ob.len and ob.start == 0 —
                    // verify, since the vectorized kernels check length equality.
                    f_vec(
                        &lhs[src_i..src_i + ob.len],
                        rhs,
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            // Each rhs element repeats `right_broadcast` times: scalar loop that
            // walks rhs with two counters instead of recomputing a strided index.
            Some(ob) => {
                let mut i_in_block = 0;
                let mut i_right_broadcast = 0;
                lhs[o_l1..o_l2]
                    .iter()
                    .map(|&l| {
                        // SAFETY(review): relies on `i_in_block + ob.start` being
                        // in bounds for rhs; follows from offsets_b invariants —
                        // confirm those guarantee ob.start + ob.len <= rhs.len().
                        let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
                        i_right_broadcast += 1;
                        if i_right_broadcast >= ob.right_broadcast {
                            i_in_block += 1;
                            i_right_broadcast = 0;
                        }
                        if i_in_block >= ob.len {
                            i_in_block = 0
                        }
                        f(l, *r)
                    })
                    .collect()
            }
            // No exploitable structure: generic strided scalar path.
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        // Mirror image of the case above: rhs contiguous, lhs broadcast.
        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_r1..o_r2).step_by(ob.len) {
                    // NOTE(review): same un-sliced concern as the rhs-broadcast
                    // branch — `lhs` is passed whole; verify caller invariants.
                    f_vec(
                        lhs,
                        &rhs[src_i..src_i + ob.len],
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let mut i_in_block = 0;
                let mut i_right_broadcast = 0;
                rhs[o_r1..o_r2]
                    .iter()
                    .map(|&r| {
                        // SAFETY(review): same bounds reliance as the mirrored
                        // branch; depends on offsets_b invariants — confirm.
                        let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
                        i_right_broadcast += 1;
                        if i_right_broadcast >= ob.right_broadcast {
                            i_in_block += 1;
                            i_right_broadcast = 0;
                        }
                        if i_in_block >= ob.len {
                            i_in_block = 0
                        }
                        f(*l, r)
                    })
                    .collect()
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        // Neither side contiguous: fully general strided scalar fallback.
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}
struct Affine(f64, f64); struct Affine(f64, f64);
impl Map1 for Affine { impl Map1 for Affine {
@ -961,27 +1074,51 @@ impl BackendStorage for CpuStorage {
fn binary_impl<B: BinaryOp>(&self, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> { fn binary_impl<B: BinaryOp>(&self, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
match (self, rhs) { match (self, rhs) {
(Self::BF16(lhs), Self::BF16(rhs)) => { (Self::BF16(lhs), Self::BF16(rhs)) => {
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16); let data = if B::BF16_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::bf16, B::bf16_vec)
} else {
binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16)
};
Ok(Self::BF16(data)) Ok(Self::BF16(data))
} }
(Self::F16(lhs), Self::F16(rhs)) => { (Self::F16(lhs), Self::F16(rhs)) => {
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f16); let data = if B::F16_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f16, B::f16_vec)
} else {
binary_map(lhs_l, rhs_l, lhs, rhs, B::f16)
};
Ok(Self::F16(data)) Ok(Self::F16(data))
} }
(Self::F32(lhs), Self::F32(rhs)) => { (Self::F32(lhs), Self::F32(rhs)) => {
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f32); let data = if B::F32_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f32, B::f32_vec)
} else {
binary_map(lhs_l, rhs_l, lhs, rhs, B::f32)
};
Ok(Self::F32(data)) Ok(Self::F32(data))
} }
(Self::F64(lhs), Self::F64(rhs)) => { (Self::F64(lhs), Self::F64(rhs)) => {
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f64); let data = if B::F64_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f64, B::f64_vec)
} else {
binary_map(lhs_l, rhs_l, lhs, rhs, B::f64)
};
Ok(Self::F64(data)) Ok(Self::F64(data))
} }
(Self::U32(lhs), Self::U32(rhs)) => { (Self::U32(lhs), Self::U32(rhs)) => {
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32); let data = if B::U32_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u32, B::u32_vec)
} else {
binary_map(lhs_l, rhs_l, lhs, rhs, B::u32)
};
Ok(Self::U32(data)) Ok(Self::U32(data))
} }
(Self::U8(lhs), Self::U8(rhs)) => { (Self::U8(lhs), Self::U8(rhs)) => {
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u8); let data = if B::U8_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
} else {
binary_map(lhs_l, rhs_l, lhs, rhs, B::u8)
};
Ok(Self::U8(data)) Ok(Self::U8(data))
} }
_ => { _ => {

View File

@ -7,6 +7,15 @@ mod ffi {
pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float); pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double); pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsAdd(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdAdd(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsSub(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdSub(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsMul(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn sgemm_( pub fn sgemm_(
transa: *const c_char, transa: *const c_char,
transb: *const c_char, transb: *const c_char,
@ -190,6 +199,7 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) } unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
} }
#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) { pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) { for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v) *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
@ -200,6 +210,7 @@ pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
} }
} }
#[inline]
pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) { pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) { for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v) *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
@ -209,3 +220,29 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
*y = 0.5 * v * (1.0 + *y) *y = 0.5 * v * (1.0 + *y)
} }
} }
/// Generates a safe wrapper around an MKL vector-math binary function.
///
/// `$fn_name` is the generated Rust function name, `$ty` the element type
/// (`f32` or `f64`), and `$mkl_name` the extern MKL symbol to call
/// (e.g. `vsAdd`). The wrapper computes `y[i] = op(a[i], b[i])` for all `i`.
///
/// # Panics
/// Panics if `a`, `b`, and `y` do not all have the same length — this guard is
/// what makes the unchecked FFI call below sound.
macro_rules! binary_op {
    ($fn_name:ident, $ty:ty, $mkl_name:ident) => {
        #[inline]
        pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) {
            let a_len = a.len();
            let b_len = b.len();
            let y_len = y.len();
            if a_len != y_len || b_len != y_len {
                panic!(
                    "{} a,b,y len mismatch {a_len} {b_len} {y_len}",
                    stringify!($fn_name)
                );
            }
            // SAFETY: all three slices are the same length (checked above), so
            // the MKL routine reads/writes exactly `a_len` in-bounds elements.
            // NOTE(review): `a_len as i32` truncates for slices longer than
            // i32::MAX — presumably out of range in practice, but worth a guard.
            unsafe { ffi::$mkl_name(a_len as i32, a.as_ptr(), b.as_ptr(), y.as_mut_ptr()) }
        }
    };
}
// Instantiate the four elementwise ops for f32 (`vs*`) and f64 (`vd*`).
binary_op!(vs_add, f32, vsAdd);
binary_op!(vd_add, f64, vdAdd);
binary_op!(vs_sub, f32, vsSub);
binary_op!(vd_sub, f64, vdSub);
binary_op!(vs_mul, f32, vsMul);
binary_op!(vd_mul, f64, vdMul);
binary_op!(vs_div, f32, vsDiv);
binary_op!(vd_div, f64, vdDiv);

View File

@ -83,6 +83,19 @@ pub(crate) trait BinaryOp {
fn f64(v1: f64, v2: f64) -> f64; fn f64(v1: f64, v2: f64) -> f64;
fn u8(v1: u8, v2: u8) -> u8; fn u8(v1: u8, v2: u8) -> u8;
fn u32(v1: u32, v2: u32) -> u32; fn u32(v1: u32, v2: u32) -> u32;
const BF16_VEC: bool = false;
fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
const F16_VEC: bool = false;
fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {}
const F32_VEC: bool = false;
fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {}
const F64_VEC: bool = false;
fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {}
const U8_VEC: bool = false;
fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
const U32_VEC: bool = false;
fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
} }
pub(crate) struct Add; pub(crate) struct Add;
@ -101,7 +114,7 @@ pub(crate) struct Gelu;
pub(crate) struct Relu; pub(crate) struct Relu;
macro_rules! bin_op { macro_rules! bin_op {
($op:ident, $name: literal, $e: expr) => { ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
impl BinaryOp for $op { impl BinaryOp for $op {
const NAME: &'static str = $name; const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("b", $name); const KERNEL: &'static str = concat!("b", $name);
@ -130,14 +143,29 @@ macro_rules! bin_op {
fn u32(v1: u32, v2: u32) -> u32 { fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2) $e(v1, v2)
} }
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
crate::mkl::$f32_vec(xs1, xs2, ys)
}
#[cfg(feature = "mkl")]
#[inline(always)]
fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
crate::mkl::$f64_vec(xs1, xs2, ys)
}
} }
}; };
} }
bin_op!(Add, "add", |v1, v2| v1 + v2); bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
bin_op!(Sub, "sub", |v1, v2| v1 - v2); bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
bin_op!(Mul, "mul", |v1, v2| v1 * v2); bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
bin_op!(Div, "div", |v1, v2| v1 / v2); bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
macro_rules! unary_op { macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => { ($op: ident, $name: literal, $a: ident, $e: expr) => {

View File

@ -115,9 +115,12 @@ fn main() -> Result<()> {
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let token_type_ids = token_ids.zeros_like()?; let token_type_ids = token_ids.zeros_like()?;
println!("Loaded and encoded {:?}", start.elapsed()); println!("Loaded and encoded {:?}", start.elapsed());
for _ in 0..args.n { for idx in 0..args.n {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let _ys = model.forward(&token_ids, &token_type_ids)?; let ys = model.forward(&token_ids, &token_type_ids)?;
if idx == 0 {
println!("{ys}");
}
println!("Took {:?}", start.elapsed()); println!("Took {:?}", start.elapsed());
} }
} else { } else {