diff --git a/candle-core/src/mkl.rs b/candle-core/src/mkl.rs index 3d71fa6a..9d18d054 100644 --- a/candle-core/src/mkl.rs +++ b/candle-core/src/mkl.rs @@ -6,6 +6,16 @@ mod ffi { extern "C" { pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float); pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double); + pub fn vsExp(n: c_int, a: *const c_float, y: *mut c_float); + pub fn vdExp(n: c_int, a: *const c_double, y: *mut c_double); + pub fn vsLn(n: c_int, a: *const c_float, y: *mut c_float); + pub fn vdLn(n: c_int, a: *const c_double, y: *mut c_double); + pub fn vsSin(n: c_int, a: *const c_float, y: *mut c_float); + pub fn vdSin(n: c_int, a: *const c_double, y: *mut c_double); + pub fn vsCos(n: c_int, a: *const c_float, y: *mut c_float); + pub fn vdCos(n: c_int, a: *const c_double, y: *mut c_double); + pub fn vsSqrt(n: c_int, a: *const c_float, y: *mut c_float); + pub fn vdSqrt(n: c_int, a: *const c_double, y: *mut c_double); pub fn vsAdd(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float); pub fn vdAdd(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double); @@ -166,6 +176,126 @@ pub unsafe fn hgemm( ) } +#[inline] +pub fn vs_exp(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsExp(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_exp(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdExp(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_ln(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsLn(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_ln(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdLn(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_sin(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsSin(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_sin(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdSin(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_cos(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsCos(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_cos(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdCos(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_sqrt(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsSqrt(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_sqrt(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdSqrt(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_sqr(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsMul(a_len as i32, a.as_ptr(), a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_sqr(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdMul(a_len as i32, a.as_ptr(), a.as_ptr(), y.as_mut_ptr()) } +} + #[inline] fn vs_tanh(a: &[f32], y: &mut [f32]) { let a_len = a.len(); diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 1344cf50..07ee7670 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -199,16 +199,63 @@ macro_rules! unary_op { } } }; + + ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => { + impl UnaryOp for $op { + const NAME: &'static str = $name; + const KERNEL: &'static str = concat!("u", $name); + const V: Self = $op; + #[inline(always)] + fn bf16($a: bf16) -> bf16 { + $e + } + #[inline(always)] + fn f16($a: f16) -> f16 { + $e + } + #[inline(always)] + fn f32($a: f32) -> f32 { + $e + } + #[inline(always)] + fn f64($a: f64) -> f64 { + $e + } + #[inline(always)] + fn u8(_: u8) -> u8 { + todo!("no unary function for u8") + } + #[inline(always)] + fn u32(_: u32) -> u32 { + todo!("no unary function for u32") + } + + #[cfg(feature = "mkl")] + const F32_VEC: bool = true; + #[cfg(feature = "mkl")] + const F64_VEC: bool = true; + #[cfg(feature = "mkl")] + #[inline(always)] + fn f32_vec(xs: &[f32], ys: &mut [f32]) { + crate::mkl::$f32_vec(xs, ys) + } + #[cfg(feature = "mkl")] + #[inline(always)] + fn f64_vec(xs: &[f64], ys: &mut [f64]) { + crate::mkl::$f64_vec(xs, ys) + } + } + }; } -unary_op!(Exp, "exp", v, v.exp()); -unary_op!(Log, "log", v, v.ln()); -unary_op!(Sin, "sin", v, v.sin()); -unary_op!(Cos, "cos", v, v.cos()); +unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp); +unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln); +unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin); +unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos); unary_op!(Abs, "abs", v, v.abs()); unary_op!(Neg, "neg", v, -v); -unary_op!(Sqr, "sqr", v, v * v); -unary_op!(Sqrt, "sqrt", v, v.sqrt()); +unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr); +unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt); /// `gelu` operation ///