diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 00bd3033..a466f88f 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -276,6 +276,119 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     }
 }
 
+fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
+    lhs_l: &Layout,
+    rhs_l: &Layout,
+    lhs: &[T],
+    rhs: &[T],
+    mut f: F,
+    mut f_vec: FV,
+) -> Vec<T> {
+    let el_count = lhs_l.shape().elem_count();
+    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
+        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
+            let mut ys: Vec<T> = Vec::with_capacity(el_count);
+            let ys_to_set = ys.spare_capacity_mut();
+            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
+            // SAFETY: values are all set by f_vec.
+            unsafe { ys.set_len(el_count) };
+            ys
+        }
+        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
+            Some(ob) if ob.right_broadcast == 1 => {
+                let mut ys: Vec<T> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+                let mut dst_i = 0;
+                for src_i in (o_l1..o_l2).step_by(ob.len) {
+                    f_vec(
+                        &lhs[src_i..src_i + ob.len],
+                        rhs,
+                        &mut ys_to_set[dst_i..dst_i + ob.len],
+                    );
+                    dst_i += ob.len;
+                }
+                // SAFETY: values are all set by f_vec.
+                unsafe { ys.set_len(el_count) };
+                ys
+            }
+            Some(ob) => {
+                let mut i_in_block = 0;
+                let mut i_right_broadcast = 0;
+                lhs[o_l1..o_l2]
+                    .iter()
+                    .map(|&l| {
+                        let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
+                        i_right_broadcast += 1;
+                        if i_right_broadcast >= ob.right_broadcast {
+                            i_in_block += 1;
+                            i_right_broadcast = 0;
+                        }
+                        if i_in_block >= ob.len {
+                            i_in_block = 0
+                        }
+                        f(l, *r)
+                    })
+                    .collect()
+            }
+            None => lhs_l
+                .strided_index()
+                .zip(rhs_l.strided_index())
+                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+                .collect(),
+        },
+        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
+            Some(ob) if ob.right_broadcast == 1 => {
+                let mut ys: Vec<T> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+                let mut dst_i = 0;
+                for src_i in (o_r1..o_r2).step_by(ob.len) {
+                    f_vec(
+                        lhs,
+                        &rhs[src_i..src_i + ob.len],
+                        &mut ys_to_set[dst_i..dst_i + ob.len],
+                    );
+                    dst_i += ob.len;
+                }
+                // SAFETY: values are all set by f_vec.
+                unsafe { ys.set_len(el_count) };
+                ys
+            }
+            Some(ob) => {
+                let mut i_in_block = 0;
+                let mut i_right_broadcast = 0;
+                rhs[o_r1..o_r2]
+                    .iter()
+                    .map(|&r| {
+                        let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
+                        i_right_broadcast += 1;
+                        if i_right_broadcast >= ob.right_broadcast {
+                            i_in_block += 1;
+                            i_right_broadcast = 0;
+                        }
+                        if i_in_block >= ob.len {
+                            i_in_block = 0
+                        }
+                        f(*l, r)
+                    })
+                    .collect()
+            }
+            None => lhs_l
+                .strided_index()
+                .zip(rhs_l.strided_index())
+                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+                .collect(),
+        },
+        _ => lhs_l
+            .strided_index()
+            .zip(rhs_l.strided_index())
+            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+            .collect(),
+    }
+}
+
 struct Affine(f64, f64);
 
 impl Map1 for Affine {
@@ -961,27 +1074,51 @@ impl BackendStorage for CpuStorage {
     fn binary_impl<B: BinaryOp>(&self, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
         match (self, rhs) {
             (Self::BF16(lhs), Self::BF16(rhs)) => {
-                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16);
+                let data = if B::BF16_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::bf16, B::bf16_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16)
+                };
                 Ok(Self::BF16(data))
             }
             (Self::F16(lhs), Self::F16(rhs)) => {
-                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f16);
+                let data = if B::F16_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f16, B::f16_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f16)
+                };
                 Ok(Self::F16(data))
             }
             (Self::F32(lhs), Self::F32(rhs)) => {
-                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f32);
+                let data = if B::F32_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f32, B::f32_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f32)
+                };
                 Ok(Self::F32(data))
             }
             (Self::F64(lhs), Self::F64(rhs)) => {
-                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f64);
+                let data = if B::F64_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f64, B::f64_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f64)
+                };
                 Ok(Self::F64(data))
            }
             (Self::U32(lhs), Self::U32(rhs)) => {
-                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32);
+                let data = if B::U32_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u32, B::u32_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u32)
+                };
                 Ok(Self::U32(data))
             }
             (Self::U8(lhs), Self::U8(rhs)) => {
-                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u8);
+                let data = if B::U8_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u8)
+                };
                 Ok(Self::U8(data))
             }
             _ => {
diff --git a/candle-core/src/mkl.rs b/candle-core/src/mkl.rs
index 60bddcb4..3d71fa6a 100644
--- a/candle-core/src/mkl.rs
+++ b/candle-core/src/mkl.rs
@@ -7,6 +7,15 @@ mod ffi {
     pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float);
     pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double);
 
+    pub fn vsAdd(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
+    pub fn vdAdd(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
+    pub fn vsSub(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
+    pub fn vdSub(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
+    pub fn vsMul(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
+    pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
+    pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
+    pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
+
     pub fn sgemm_(
         transa: *const c_char,
         transb: *const c_char,
@@ -190,6 +199,7 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
     unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
 }
 
+#[inline]
 pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
     for (&v, y) in vs.iter().zip(ys.iter_mut()) {
         *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
@@ -200,6 +210,7 @@ pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
     }
 }
 
+#[inline]
 pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
     for (&v, y) in vs.iter().zip(ys.iter_mut()) {
         *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
@@ -209,3 +220,29 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
         *y = 0.5 * v * (1.0 + *y)
     }
 }
+
+macro_rules! binary_op {
+    ($fn_name:ident, $ty:ty, $mkl_name:ident) => {
+        #[inline]
+        pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) {
+            let a_len = a.len();
+            let b_len = b.len();
+            let y_len = y.len();
+            if a_len != y_len || b_len != y_len {
+                panic!(
+                    "{} a,b,y len mismatch {a_len} {b_len} {y_len}",
+                    stringify!($fn_name)
+                );
+            }
+            unsafe { ffi::$mkl_name(a_len as i32, a.as_ptr(), b.as_ptr(), y.as_mut_ptr()) }
+        }
+    };
+}
+binary_op!(vs_add, f32, vsAdd);
+binary_op!(vd_add, f64, vdAdd);
+binary_op!(vs_sub, f32, vsSub);
+binary_op!(vd_sub, f64, vdSub);
+binary_op!(vs_mul, f32, vsMul);
+binary_op!(vd_mul, f64, vdMul);
+binary_op!(vs_div, f32, vsDiv);
+binary_op!(vd_div, f64, vdDiv);
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index ec91a3fc..1344cf50 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -83,6 +83,19 @@ pub(crate) trait BinaryOp {
     fn f64(v1: f64, v2: f64) -> f64;
     fn u8(v1: u8, v2: u8) -> u8;
     fn u32(v1: u32, v2: u32) -> u32;
+
+    const BF16_VEC: bool = false;
+    fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
+    const F16_VEC: bool = false;
+    fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {}
+    const F32_VEC: bool = false;
+    fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {}
+    const F64_VEC: bool = false;
+    fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {}
+    const U8_VEC: bool = false;
+    fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
+    const U32_VEC: bool = false;
+    fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
 }
 
 pub(crate) struct Add;
@@ -101,7 +114,7 @@ pub(crate) struct Gelu;
 pub(crate) struct Relu;
 
 macro_rules! bin_op {
-    ($op:ident, $name: literal, $e: expr) => {
+    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
         impl BinaryOp for $op {
             const NAME: &'static str = $name;
             const KERNEL: &'static str = concat!("b", $name);
@@ -130,14 +143,29 @@ macro_rules! bin_op {
             fn u32(v1: u32, v2: u32) -> u32 {
                 $e(v1, v2)
             }
+
+            #[cfg(feature = "mkl")]
+            const F32_VEC: bool = true;
+            #[cfg(feature = "mkl")]
+            const F64_VEC: bool = true;
+            #[cfg(feature = "mkl")]
+            #[inline(always)]
+            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
+                crate::mkl::$f32_vec(xs1, xs2, ys)
+            }
+            #[cfg(feature = "mkl")]
+            #[inline(always)]
+            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
+                crate::mkl::$f64_vec(xs1, xs2, ys)
+            }
         }
     };
 }
 
-bin_op!(Add, "add", |v1, v2| v1 + v2);
-bin_op!(Sub, "sub", |v1, v2| v1 - v2);
-bin_op!(Mul, "mul", |v1, v2| v1 * v2);
-bin_op!(Div, "div", |v1, v2| v1 / v2);
+bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
+bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
+bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
+bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
 
 macro_rules! unary_op {
     ($op: ident, $name: literal, $a: ident, $e: expr) => {
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 958b70b1..8ef8b5ce 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -115,9 +115,12 @@ fn main() -> Result<()> {
         let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
         let token_type_ids = token_ids.zeros_like()?;
         println!("Loaded and encoded {:?}", start.elapsed());
-        for _ in 0..args.n {
+        for idx in 0..args.n {
             let start = std::time::Instant::now();
-            let _ys = model.forward(&token_ids, &token_type_ids)?;
+            let ys = model.forward(&token_ids, &token_type_ids)?;
+            if idx == 0 {
+                println!("{ys}");
+            }
             println!("Took {:?}", start.elapsed());
         }
     } else {