mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Fix the accelerate build (#678)
* Cosmetic changes. * Fix the accelerate build for tanh.
This commit is contained in:
@ -50,6 +50,8 @@ mod ffi {
|
||||
pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||
pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
pub fn vvtanhf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||
pub fn vvtanh(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||
|
||||
pub fn vDSP_vaddD(
|
||||
_: *const c_double,
|
||||
@ -308,6 +310,26 @@ pub fn vd_cos(a: &[f64], y: &mut [f64]) {
|
||||
}
|
||||
unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
#[inline]
|
||||
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvtanhf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||
}
|
||||
unsafe { ffi::vvtanh(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
|
Reference in New Issue
Block a user