mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add support for mkl tanh. (#185)
This commit is contained in:
@ -3,6 +3,9 @@ use libc::{c_char, c_double, c_float, c_int};
|
|||||||
mod ffi {
|
mod ffi {
|
||||||
use super::*;
|
use super::*;
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float);
|
||||||
|
pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double);
|
||||||
|
|
||||||
pub fn sgemm_(
|
pub fn sgemm_(
|
||||||
transa: *const c_char,
|
transa: *const c_char,
|
||||||
transb: *const c_char,
|
transb: *const c_char,
|
||||||
@ -152,3 +155,25 @@ pub unsafe fn hgemm(
|
|||||||
&ldc,
|
&ldc,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[inline]
|
||||||
|
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||||
|
let a_len = a.len();
|
||||||
|
let y_len = y.len();
|
||||||
|
if a_len != y_len {
|
||||||
|
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||||
|
}
|
||||||
|
unsafe { ffi::vsTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[inline]
|
||||||
|
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||||
|
let a_len = a.len();
|
||||||
|
let y_len = y.len();
|
||||||
|
if a_len != y_len {
|
||||||
|
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||||
|
}
|
||||||
|
unsafe { ffi::vdTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user