mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Use the same default as pytorch for sum. (#164)
This commit is contained in:
@ -95,7 +95,7 @@ impl RmsNorm {
|
||||
// This is a no-op if x's dtype is already f32.
|
||||
let x = x.to_dtype(DType::F32)?;
|
||||
let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
|
||||
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
|
||||
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
||||
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
||||
let size = self.scale.shape().r1()?;
|
||||
|
Reference in New Issue
Block a user