mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Mklize more unary ops. (#191)
* Mklize more unary ops. * Even more unary ops.
This commit is contained in:
@ -199,16 +199,63 @@ macro_rules! unary_op {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
|
||||
impl UnaryOp for $op {
|
||||
const NAME: &'static str = $name;
|
||||
const KERNEL: &'static str = concat!("u", $name);
|
||||
const V: Self = $op;
|
||||
#[inline(always)]
|
||||
fn bf16($a: bf16) -> bf16 {
|
||||
$e
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16($a: f16) -> f16 {
|
||||
$e
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32($a: f32) -> f32 {
|
||||
$e
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64($a: f64) -> f64 {
|
||||
$e
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
todo!("no unary function for u8")
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
todo!("no unary function for u32")
|
||||
}
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F32_VEC: bool = true;
|
||||
#[cfg(feature = "mkl")]
|
||||
const F64_VEC: bool = true;
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::mkl::$f32_vec(xs, ys)
|
||||
}
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::mkl::$f64_vec(xs, ys)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
unary_op!(Exp, "exp", v, v.exp());
|
||||
unary_op!(Log, "log", v, v.ln());
|
||||
unary_op!(Sin, "sin", v, v.sin());
|
||||
unary_op!(Cos, "cos", v, v.cos());
|
||||
unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
|
||||
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
|
||||
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
|
||||
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
|
||||
unary_op!(Abs, "abs", v, v.abs());
|
||||
unary_op!(Neg, "neg", v, -v);
|
||||
unary_op!(Sqr, "sqr", v, v * v);
|
||||
unary_op!(Sqrt, "sqrt", v, v.sqrt());
|
||||
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
|
||||
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
||||
|
||||
/// `gelu` operation
|
||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
||||
|
Reference in New Issue
Block a user