Use the same default as pytorch for sum. (#164)

This commit is contained in:
Laurent Mazare
2023-07-13 21:32:32 +01:00
committed by GitHub
parent 57be3638d8
commit 2bfa791336
13 changed files with 123 additions and 56 deletions

View File

@ -51,9 +51,9 @@ impl LayerNorm {
};
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
let x = x.to_dtype(internal_dtype)?;
let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
let mean_x = (x.sum_keepdim(&[2])? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.to_dtype(x_dtype)?