Fix some shape issues in falcon. (#95)

* Fix some shape issues.

* Use different dtypes.
Laurent Mazare
2023-07-06 19:23:54 +01:00
committed by GitHub
parent 4afa461b34
commit 0f679fe42e
2 changed files with 21 additions and 7 deletions


@@ -10,7 +10,10 @@ use clap::Parser;
 mod model;
 use model::{Config, Falcon, VarBuilder};
-const DTYPE: DType = DType::F16;
+#[cfg(feature = "mkl")]
+const DTYPE: DType = DType::F32;
+#[cfg(not(feature = "mkl"))]
+const DTYPE: DType = DType::BF16;
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
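The feature-gated constant above selects the compute dtype at build time. A minimal standalone sketch of the same pattern; the `candle_core` import path (inside the repo the crate is typically aliased to `candle`), the declared `mkl` Cargo feature, and the toy cast standing in for weight loading are assumptions, not part of this commit:

```rust
use candle_core::{DType, Device, Result, Tensor};

// Compile-time dtype selection, mirroring the patch: f32 when the `mkl`
// feature is enabled, bf16 otherwise.
#[cfg(feature = "mkl")]
const DTYPE: DType = DType::F32;
#[cfg(not(feature = "mkl"))]
const DTYPE: DType = DType::BF16;

fn main() -> Result<()> {
    // Hypothetical stand-in for weight loading: cast a tensor to the
    // feature-selected dtype.
    let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?;
    let t = t.to_dtype(DTYPE)?;
    println!("running with {:?}", t.dtype());
    Ok(())
}
```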


@@ -421,13 +421,24 @@ impl FalconAttention {
         };
         let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
         // TODO: layer_past, use_cache?
-        let query = query.reshape((b_sz, self.num_heads, q_len, head_dim))?;
-        let key = key.reshape((b_sz, self.n_head_kv, q_len, head_dim))?;
-        let value = value.reshape((b_sz, self.n_head_kv, q_len, head_dim))?;
+        let query = query.reshape((b_sz * self.num_heads, q_len, head_dim))?;
+        let key = key.reshape((b_sz * self.n_head_kv, q_len, head_dim))?;
+        let value = value.reshape((b_sz * self.n_head_kv, q_len, head_dim))?;
+        let (key, value) = if self.n_head_kv == 1 {
+            (
+                key.broadcast_as(query.dims())?,
+                value.broadcast_as(query.dims())?,
+            )
+        } else {
+            (key, value)
+        };
         // Only handle alibi is None here, and non-flash attention.
         let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
-        let attention_scores = (attention_scores + mask)?.softmax(D::Minus1)?;
+        let attention_scores = attention_scores
+            .broadcast_add(&mask.squeeze(1)?)?
+            .softmax(D::Minus1)?;
         let attn_output = attention_scores
             .matmul(&value)?
             .reshape((b_sz, self.num_heads, q_len, head_dim))?
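For context, a hedged sketch of how the new reshapes and the `broadcast_as` call fit together: with batch and head merged into the leading axis, the single K/V head used by multi-query attention (`n_head_kv == 1`) can be broadcast to the query's leading dimension and fed through a batched matmul. The toy sizes, the `candle_core` import path, and the extra `.contiguous()` call are assumptions added to keep the example self-contained, not code from this commit:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Toy sizes; batch size 1 so the single K/V head broadcasts cleanly.
    let (b_sz, num_heads, n_head_kv, q_len, head_dim) = (1usize, 8, 1, 4, 16);

    // Batch and head dimensions merged into the leading axis, as in the patch.
    let query = Tensor::zeros((b_sz * num_heads, q_len, head_dim), DType::F32, &dev)?;
    let key = Tensor::zeros((b_sz * n_head_kv, q_len, head_dim), DType::F32, &dev)?;

    // Multi-query attention: one K/V head shared by every query head, so
    // broadcast it up to the query's leading dimension before the batched
    // matmul; `.contiguous()` materializes the broadcast copy defensively.
    let key = key.broadcast_as(query.dims())?.contiguous()?;
    let scores = query.matmul(&key.t()?)?;
    assert_eq!(scores.dims(), &[b_sz * num_heads, q_len, q_len]);
    Ok(())
}
```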
@@ -459,8 +470,8 @@ impl FalconMlp {
     }
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = self.dense_4h_to_h.forward(x)?.gelu()?;
-        let x = self.dense_h_to_4h.forward(&x)?;
+        let x = self.dense_h_to_4h.forward(x)?.gelu()?;
+        let x = self.dense_4h_to_h.forward(&x)?;
         Ok(x)
     }
 }
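A small sketch of the corrected MLP order (up-project to 4×hidden, apply GELU, project back down), with plain matmuls standing in for the `dense_h_to_4h` / `dense_4h_to_h` linear layers; the shapes and weight names are made up for illustration:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (seq_len, hidden) = (4usize, 32);
    let x = Tensor::zeros((seq_len, hidden), DType::F32, &dev)?;
    // Hypothetical stand-ins for the dense_h_to_4h / dense_4h_to_h weights.
    let w_h_to_4h = Tensor::zeros((hidden, 4 * hidden), DType::F32, &dev)?;
    let w_4h_to_h = Tensor::zeros((4 * hidden, hidden), DType::F32, &dev)?;

    // Corrected order: expand to 4*hidden, apply GELU, then project back.
    let x = x.matmul(&w_h_to_4h)?.gelu()?;
    let x = x.matmul(&w_4h_to_h)?;
    assert_eq!(x.dims(), &[seq_len, hidden]);
    Ok(())
}
```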