From 9874d843f13ef46e34a2cc9167ba0ce1d614caef Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 30 Aug 2023 18:31:14 +0200
Subject: [PATCH] Fix the accelerate build (#678)

* Cosmetic changes.

* Fix the accelerate build for tanh.
---
 candle-core/src/accelerate.rs             | 22 ++++++++++++++++++++++
 candle-examples/examples/whisper/model.rs |  7 +++----
 2 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/candle-core/src/accelerate.rs b/candle-core/src/accelerate.rs
index fc9c2441..87e0ee8d 100644
--- a/candle-core/src/accelerate.rs
+++ b/candle-core/src/accelerate.rs
@@ -50,6 +50,8 @@ mod ffi {
     pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
     pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
     pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
+    pub fn vvtanhf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+    pub fn vvtanh(dst: *mut c_double, src: *const c_double, len: *const c_int);
 
     pub fn vDSP_vaddD(
         _: *const c_double,
@@ -308,6 +310,26 @@ pub fn vd_cos(a: &[f64], y: &mut [f64]) {
     }
     unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
 }
+#[inline]
+pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvtanhf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvtanh(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
 #[inline]
 pub fn vs_ln(a: &[f32], y: &mut [f32]) {
     let a_len = a.len();
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index d6bea09a..e58ab2ca 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -1,4 +1,4 @@
-use candle::{Device, IndexOp, Result, Tensor};
+use candle::{Device, IndexOp, Result, Tensor, D};
 use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
 use serde::Deserialize;
 
@@ -166,7 +166,7 @@ impl MultiHeadAttention {
         }
         let w = {
             let _enter = self.softmax_span.enter();
-            softmax(&qk, candle::D::Minus1)?
+            softmax(&qk, D::Minus1)?
         };
         let wv = {
             let _enter = self.matmul_span.enter();
@@ -375,8 +375,7 @@ impl TextDecoder {
 
     pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let x_dims = x.dims();
-        let last = x_dims[x_dims.len() - 1];
+        let last = x.dim(D::Minus1)?;
         let token_embedding = self.token_embedding.forward(x)?;
         let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
         let mut x = token_embedding.broadcast_add(&positional_embedding)?;
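
---

Note (editorial, not part of the patch): the snippet below is a minimal
standalone sketch of the pattern the accelerate.rs hunks add, an FFI
binding to Accelerate's vvtanhf plus a length-checked safe wrapper. It is
an illustration, not candle's actual module: it only builds on macOS,
where the Accelerate framework provides the vForce vvtanhf symbol, and
the wrapper name vs_tanh is borrowed from the patch for readability.

use std::os::raw::{c_float, c_int};

#[link(name = "Accelerate", kind = "framework")]
extern "C" {
    // vForce: element-wise tanh over a float buffer; the length is
    // passed by pointer, matching the vvtanhf declaration in the patch.
    fn vvtanhf(dst: *mut c_float, src: *const c_float, len: *const c_int);
}

// Length-checked safe wrapper, mirroring the vs_tanh added above.
fn vs_tanh(a: &[f32], y: &mut [f32]) {
    assert_eq!(a.len(), y.len(), "a and y have different lengths");
    let len = a.len() as c_int;
    unsafe { vvtanhf(y.as_mut_ptr(), a.as_ptr(), &len) }
}

fn main() {
    let xs = [-1.0f32, 0.0, 0.5, 2.0];
    let mut ys = [0.0f32; 4];
    vs_tanh(&xs, &mut ys);
    // Cross-check the vectorized results against scalar f32::tanh.
    for (x, y) in xs.iter().zip(ys.iter()) {
        assert!((x.tanh() - y).abs() < 1e-6);
    }
    println!("vvtanhf agrees with f32::tanh: {ys:?}");
}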
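
The whisper/model.rs hunks are the cosmetic half of the change: importing
D at the top lets call sites write D::Minus1 instead of candle::D::Minus1,
and x.dim(D::Minus1)? replaces the manual "last element of dims()" lookup.
A tiny sketch of that equivalence (illustrative only; assumes the candle
crate as imported in the example code):

use candle::{DType, Device, Result, Tensor, D};

fn main() -> Result<()> {
    let x = Tensor::zeros((2, 3, 5), DType::F32, &Device::Cpu)?;

    // Manual form removed by the patch: index the last entry of dims().
    let dims = x.dims();
    let manual = dims[dims.len() - 1];

    // D::Minus1 resolves to the same (last) dimension.
    let via_d = x.dim(D::Minus1)?;

    assert_eq!(manual, via_d); // both are 5 here
    Ok(())
}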