Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00
Use F32 for the reduce ops. (#105)
@@ -122,12 +122,15 @@ impl LayerNorm {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let dtype = x.dtype();
         let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let x = x.to_dtype(DType::F32)?;
         let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
         let x = x.broadcast_sub(&mean_x)?;
         let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed
+            .to_dtype(dtype)?
             .broadcast_mul(&self.weight)?
             .broadcast_add(&self.bias)?;
         Ok(x)
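The LayerNorm hunk above follows a simple pattern: remember the input dtype, up-cast to F32 before the mean/variance reductions, then cast back before the affine step so the F16/BF16 weights and bias are applied in the original precision. Below is a minimal sketch of that pattern lifted into a free function; it reuses only the calls visible in the diff (`r3`, `sum(&[2])`, the broadcast ops, `to_dtype`), while the `candle` crate path and the `layer_norm_f32` name are assumptions for illustration, not part of the commit.

use candle::{DType, Result, Tensor};

// Hypothetical free-function version of the LayerNorm body shown in the hunk above.
fn layer_norm_f32(x: &Tensor, weight: &Tensor, bias: &Tensor, eps: f64) -> Result<Tensor> {
    let dtype = x.dtype();
    let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
    // Run the reductions in F32 so F16/BF16 inputs do not lose precision
    // while summing over the hidden dimension.
    let x = x.to_dtype(DType::F32)?;
    let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
    let x = x.broadcast_sub(&mean_x)?;
    let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps)?.sqrt()?)?;
    // Cast back to the original dtype before applying scale and shift.
    x_normed
        .to_dtype(dtype)?
        .broadcast_mul(weight)?
        .broadcast_add(bias)
}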
@@ -470,7 +473,9 @@ impl FalconAttention {
         let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
         let attention_scores = attention_scores
             .broadcast_add(&mask.squeeze(1)?)?
-            .softmax(D::Minus1)?;
+            .to_dtype(DType::F32)?
+            .softmax(D::Minus1)?
+            .to_dtype(x.dtype())?;
         let attn_output = attention_scores
             .matmul(&value)?
             .reshape((b_sz, self.num_heads, seq_len, head_dim))?
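The FalconAttention hunk applies the same idea to the softmax reduction: the attention scores are up-cast to F32, the softmax runs in F32, and the result is cast back to the model dtype before the value matmul. Below is a minimal sketch of just that step, assuming the `candle::{D, DType}` paths and the `softmax` method on Tensor exactly as they appear in this diff; `softmax_in_f32` is a hypothetical helper name.

use candle::{D, DType, Result, Tensor};

// Hypothetical helper mirroring the attention hunk: softmax over the last
// dimension, computed in F32 and cast back to the caller's dtype.
fn softmax_in_f32(attention_scores: &Tensor, dtype: DType) -> Result<Tensor> {
    attention_scores
        .to_dtype(DType::F32)? // exponentiate and normalize in F32
        .softmax(D::Minus1)?   // softmax over the last (key) dimension
        .to_dtype(dtype)       // back to the model dtype (e.g. F16/BF16)
}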