Line-up the llama implementation with the python-transformers one. (#271)

* Line-up the llama implementation with the python-transformers one.

* Also lineup the multiprocess version.
This commit is contained in:
Laurent Mazare
2023-07-28 18:31:28 +01:00
committed by GitHub
parent cb8dd5cd53
commit 7513a5e005
2 changed files with 29 additions and 44 deletions

View File

@@ -225,7 +225,7 @@ impl RmsNorm {
let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let size = self.scale.shape().dims1()?;
let scale = self
.scale