mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Fix the accelerate build (#678)
* Cosmetic changes.
* Fix the accelerate build for tanh.
This commit is contained in:
@ -50,6 +50,8 @@ mod ffi {
|
|||||||
pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||||
pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||||
pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||||
|
pub fn vvtanhf(dst: *mut c_float, src: *const c_float, len: *const c_int);
|
||||||
|
pub fn vvtanh(dst: *mut c_double, src: *const c_double, len: *const c_int);
|
||||||
|
|
||||||
pub fn vDSP_vaddD(
|
pub fn vDSP_vaddD(
|
||||||
_: *const c_double,
|
_: *const c_double,
|
||||||
@ -308,6 +310,26 @@ pub fn vd_cos(a: &[f64], y: &mut [f64]) {
|
|||||||
}
|
}
|
||||||
unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||||
}
|
}
|
||||||
|
#[inline]
|
||||||
|
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||||
|
let a_len = a.len();
|
||||||
|
let y_len = y.len();
|
||||||
|
if a_len != y_len {
|
||||||
|
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||||
|
}
|
||||||
|
unsafe { ffi::vvtanhf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||||
|
let a_len = a.len();
|
||||||
|
let y_len = y.len();
|
||||||
|
if a_len != y_len {
|
||||||
|
panic!("a and y have different lengths {a_len} <> {y_len}")
|
||||||
|
}
|
||||||
|
unsafe { ffi::vvtanh(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
|
||||||
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
|
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
|
||||||
let a_len = a.len();
|
let a_len = a.len();
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use candle::{Device, IndexOp, Result, Tensor};
|
use candle::{Device, IndexOp, Result, Tensor, D};
|
||||||
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
@ -166,7 +166,7 @@ impl MultiHeadAttention {
|
|||||||
}
|
}
|
||||||
let w = {
|
let w = {
|
||||||
let _enter = self.softmax_span.enter();
|
let _enter = self.softmax_span.enter();
|
||||||
softmax(&qk, candle::D::Minus1)?
|
softmax(&qk, D::Minus1)?
|
||||||
};
|
};
|
||||||
let wv = {
|
let wv = {
|
||||||
let _enter = self.matmul_span.enter();
|
let _enter = self.matmul_span.enter();
|
||||||
@ -375,8 +375,7 @@ impl TextDecoder {
|
|||||||
|
|
||||||
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||||
let _enter = self.span.enter();
|
let _enter = self.span.enter();
|
||||||
let x_dims = x.dims();
|
let last = x.dim(D::Minus1)?;
|
||||||
let last = x_dims[x_dims.len() - 1];
|
|
||||||
let token_embedding = self.token_embedding.forward(x)?;
|
let token_embedding = self.token_embedding.forward(x)?;
|
||||||
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
||||||
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
||||||
|
Reference in New Issue
Block a user