mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Simplify the parameters used by sum and sum_keepdim. (#165)
This commit is contained in:
@ -51,9 +51,9 @@ impl LayerNorm {
|
||||
};
|
||||
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
||||
let x = x.to_dtype(internal_dtype)?;
|
||||
let mean_x = (x.sum_keepdim(&[2])? / hidden_size as f64)?;
|
||||
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x = x.broadcast_sub(&mean_x)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
let x = x_normed
|
||||
.to_dtype(x_dtype)?
|
||||
|
Reference in New Issue
Block a user