mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Mklize more unary ops. (#191)
* Mklize more unary ops. * Even more unary ops.
This commit is contained in:
@ -6,6 +6,16 @@ mod ffi {
|
||||
extern "C" {
|
||||
pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float);
|
||||
pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double);
|
||||
pub fn vsExp(n: c_int, a: *const c_float, y: *mut c_float);
|
||||
pub fn vdExp(n: c_int, a: *const c_double, y: *mut c_double);
|
||||
pub fn vsLn(n: c_int, a: *const c_float, y: *mut c_float);
|
||||
pub fn vdLn(n: c_int, a: *const c_double, y: *mut c_double);
|
||||
pub fn vsSin(n: c_int, a: *const c_float, y: *mut c_float);
|
||||
pub fn vdSin(n: c_int, a: *const c_double, y: *mut c_double);
|
||||
pub fn vsCos(n: c_int, a: *const c_float, y: *mut c_float);
|
||||
pub fn vdCos(n: c_int, a: *const c_double, y: *mut c_double);
|
||||
pub fn vsSqrt(n: c_int, a: *const c_float, y: *mut c_float);
|
||||
pub fn vdSqrt(n: c_int, a: *const c_double, y: *mut c_double);
|
||||
|
||||
pub fn vsAdd(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
|
||||
pub fn vdAdd(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
|
||||
@ -166,6 +176,126 @@ pub unsafe fn hgemm(
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_exp(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vsExp(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_exp(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vdExp(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vsLn(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_ln(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vdLn(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_sin(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vsSin(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_sin(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vdSin(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_cos(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vsCos(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_cos(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vdCos(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_sqrt(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vsSqrt(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_sqrt(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vdSqrt(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_sqr(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vsMul(a_len as i32, a.as_ptr(), a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vdMul(a_len as i32, a.as_ptr(), a.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
|
Reference in New Issue
Block a user