From c3a73c583e80aeaad5e369b476a337237b339d03 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 17 Jul 2023 22:06:43 +0100 Subject: [PATCH] Add support for mkl tanh. (#185) --- candle-core/src/mkl.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/candle-core/src/mkl.rs b/candle-core/src/mkl.rs index 8d1aea4e..aabe6edc 100644 --- a/candle-core/src/mkl.rs +++ b/candle-core/src/mkl.rs @@ -3,6 +3,9 @@ use libc::{c_char, c_double, c_float, c_int}; mod ffi { use super::*; extern "C" { + pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float); + pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double); + pub fn sgemm_( transa: *const c_char, transb: *const c_char, @@ -152,3 +155,25 @@ pub unsafe fn hgemm( &ldc, ) } + +#[allow(dead_code)] +#[inline] +pub fn vs_tanh(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[allow(dead_code)] +#[inline] +pub fn vd_tanh(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +}