mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Fix some shape issues in falcon. (#95)
* Fix some shape issues.
* Use different dtypes.
This commit is contained in:
@@ -10,7 +10,10 @@ use clap::Parser;
|
||||
mod model;
|
||||
use model::{Config, Falcon, VarBuilder};
|
||||
|
||||
const DTYPE: DType = DType::F16;
|
||||
#[cfg(feature = "mkl")]
|
||||
const DTYPE: DType = DType::F32;
|
||||
#[cfg(not(feature = "mkl"))]
|
||||
const DTYPE: DType = DType::BF16;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
|
@@ -421,13 +421,24 @@ impl FalconAttention {
|
||||
};
|
||||
let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
|
||||
// TODO: layer_past, use_cache?
|
||||
let query = query.reshape((b_sz, self.num_heads, q_len, head_dim))?;
|
||||
let key = key.reshape((b_sz, self.n_head_kv, q_len, head_dim))?;
|
||||
let value = value.reshape((b_sz, self.n_head_kv, q_len, head_dim))?;
|
||||
let query = query.reshape((b_sz * self.num_heads, q_len, head_dim))?;
|
||||
let key = key.reshape((b_sz * self.n_head_kv, q_len, head_dim))?;
|
||||
let value = value.reshape((b_sz * self.n_head_kv, q_len, head_dim))?;
|
||||
|
||||
let (key, value) = if self.n_head_kv == 1 {
|
||||
(
|
||||
key.broadcast_as(query.dims())?,
|
||||
value.broadcast_as(query.dims())?,
|
||||
)
|
||||
} else {
|
||||
(key, value)
|
||||
};
|
||||
|
||||
// Only handle alibi is None here, and non-flash attention.
|
||||
let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
|
||||
let attention_scores = (attention_scores + mask)?.softmax(D::Minus1)?;
|
||||
let attention_scores = attention_scores
|
||||
.broadcast_add(&mask.squeeze(1)?)?
|
||||
.softmax(D::Minus1)?;
|
||||
let attn_output = attention_scores
|
||||
.matmul(&value)?
|
||||
.reshape((b_sz, self.num_heads, q_len, head_dim))?
|
||||
@@ -459,8 +470,8 @@ impl FalconMlp {
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = self.dense_4h_to_h.forward(x)?.gelu()?;
|
||||
let x = self.dense_h_to_4h.forward(&x)?;
|
||||
let x = self.dense_h_to_4h.forward(x)?.gelu()?;
|
||||
let x = self.dense_4h_to_h.forward(&x)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user