Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Fix the accelerate build (#678)
* Cosmetic changes.
* Fix the accelerate build for tanh.
@@ -1,4 +1,4 @@
-use candle::{Device, IndexOp, Result, Tensor};
+use candle::{Device, IndexOp, Result, Tensor, D};
 use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
 use serde::Deserialize;
 
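For context, the only change here is bringing candle's dimension selector `D` into scope so later call sites can write `D::Minus1` unqualified. A minimal sketch of the shorthand, assuming a hypothetical helper (the `last_dim` name, the `(2, 3)` shape, and `DType::F32` are illustrative, not from this commit):

    use candle::{D, DType, Device, Result, Tensor};

    // Hypothetical helper: with `D` imported, the last dimension is
    // named `D::Minus1` rather than the fully qualified `candle::D::Minus1`.
    fn last_dim(device: &Device) -> Result<usize> {
        let t = Tensor::zeros((2, 3), DType::F32, device)?;
        t.dim(D::Minus1)
    }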
@@ -166,7 +166,7 @@ impl MultiHeadAttention {
         }
         let w = {
             let _enter = self.softmax_span.enter();
-            softmax(&qk, candle::D::Minus1)?
+            softmax(&qk, D::Minus1)?
         };
         let wv = {
             let _enter = self.matmul_span.enter();
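Functionally nothing changes in this hunk; `softmax(&qk, D::Minus1)` still normalizes the raw attention scores over the last axis. A minimal sketch of the same call in isolation (the `attention_weights` helper is illustrative, not from this file):

    use candle::{D, Result, Tensor};
    use candle_nn::ops::softmax;

    // Illustrative helper mirroring the MultiHeadAttention hunk above:
    // softmax over the last dimension turns scores into attention weights.
    fn attention_weights(qk: &Tensor) -> Result<Tensor> {
        softmax(qk, D::Minus1)
    }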
@@ -375,8 +375,7 @@ impl TextDecoder {
 
     pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let x_dims = x.dims();
-        let last = x_dims[x_dims.len() - 1];
+        let last = x.dim(D::Minus1)?;
         let token_embedding = self.token_embedding.forward(x)?;
         let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
         let mut x = token_embedding.broadcast_add(&positional_embedding)?;
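Here `x.dim(D::Minus1)?` replaces manual indexing into the slice returned by `x.dims()`, so an out-of-range axis surfaces as a `Result` error instead of a panic. A minimal sketch of the same positional-embedding pattern, with an illustrative free function standing in for the `TextDecoder` method:

    use candle::{D, Result, Tensor};

    // Illustrative, not this file's code: token ids `x`, their embeddings,
    // and a positional table of shape (max_len, dim).
    fn add_positional(x: &Tensor, token_embedding: &Tensor, positional_embedding: &Tensor) -> Result<Tensor> {
        // Sequence length from the last dimension of the token ids.
        let last = x.dim(D::Minus1)?;
        // First `last` rows of the positional table, broadcast-added
        // onto the token embeddings.
        let positional = positional_embedding.narrow(0, 0, last)?;
        token_embedding.broadcast_add(&positional)
    }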