mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Mklize more unary ops. (#191)
* Mklize more unary ops. * Even more unary ops.
This commit is contained in:
@ -6,6 +6,16 @@ mod ffi {
|
||||
extern "C" {
|
||||
pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float);
|
||||
pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double);
|
||||
pub fn vsExp(n: c_int, a: *const c_float, y: *mut c_float);
|
||||
pub fn vdExp(n: c_int, a: *const c_double, y: *mut c_double);
|
||||
pub fn vsLn(n: c_int, a: *const c_float, y: *mut c_float);
|
||||
pub fn vdLn(n: c_int, a: *const c_double, y: *mut c_double);
|
||||
pub fn vsSin(n: c_int, a: *const c_float, y: *mut c_float);
|
||||
pub fn vdSin(n: c_int, a: *const c_double, y: *mut c_double);
|
||||
pub fn vsCos(n: c_int, a: *const c_float, y: *mut c_float);
|
||||
pub fn vdCos(n: c_int, a: *const c_double, y: *mut c_double);
|
||||
pub fn vsSqrt(n: c_int, a: *const c_float, y: *mut c_float);
|
||||
pub fn vdSqrt(n: c_int, a: *const c_double, y: *mut c_double);
|
||||
|
||||
pub fn vsAdd(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
|
||||
pub fn vdAdd(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
|
||||
@ -166,6 +176,126 @@ pub unsafe fn hgemm(
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_exp(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vsExp(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_exp(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vdExp(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vsLn(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_ln(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vdLn(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_sin(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vsSin(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_sin(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vdSin(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_cos(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vsCos(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_cos(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vdCos(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_sqrt(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vsSqrt(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_sqrt(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vdSqrt(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_sqr(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vsMul(a_len as i32, a.as_ptr(), a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vdMul(a_len as i32, a.as_ptr(), a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
|
@ -199,16 +199,63 @@ macro_rules! unary_op {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
|
||||
impl UnaryOp for $op {
|
||||
const NAME: &'static str = $name;
|
||||
const KERNEL: &'static str = concat!("u", $name);
|
||||
const V: Self = $op;
|
||||
#[inline(always)]
|
||||
fn bf16($a: bf16) -> bf16 {
|
||||
$e
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16($a: f16) -> f16 {
|
||||
$e
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32($a: f32) -> f32 {
|
||||
$e
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64($a: f64) -> f64 {
|
||||
$e
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
todo!("no unary function for u8")
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
todo!("no unary function for u32")
|
||||
}
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F32_VEC: bool = true;
|
||||
#[cfg(feature = "mkl")]
|
||||
const F64_VEC: bool = true;
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::mkl::$f32_vec(xs, ys)
|
||||
}
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::mkl::$f64_vec(xs, ys)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
unary_op!(Exp, "exp", v, v.exp());
|
||||
unary_op!(Log, "log", v, v.ln());
|
||||
unary_op!(Sin, "sin", v, v.sin());
|
||||
unary_op!(Cos, "cos", v, v.cos());
|
||||
unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
|
||||
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
|
||||
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
|
||||
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
|
||||
unary_op!(Abs, "abs", v, v.abs());
|
||||
unary_op!(Neg, "neg", v, -v);
|
||||
unary_op!(Sqr, "sqr", v, v * v);
|
||||
unary_op!(Sqrt, "sqrt", v, v.sqrt());
|
||||
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
|
||||
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
||||
|
||||
/// `gelu` operation
|
||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
||||
|
Reference in New Issue
Block a user