Simplify the parameters used by sum and sum_keepdim. (#165)

This commit is contained in:
Laurent Mazare
2023-07-14 08:22:08 +01:00
committed by GitHub
parent 2bfa791336
commit a2f72edc0d
13 changed files with 179 additions and 98 deletions

View File

@@ -51,9 +51,9 @@ impl LayerNorm {
};
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
let x = x.to_dtype(internal_dtype)?;
-let mean_x = (x.sum_keepdim(&[2])? / hidden_size as f64)?;
+let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
-let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?;
+let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.to_dtype(x_dtype)?

View File

@@ -30,10 +30,10 @@ fn layer_norm() -> Result<()> {
[4.1742344, 0.5, -3.1742344]
]]
);
-let mean = (res.sum_keepdim(&[2])? / 3.0)?;
+let mean = (res.sum_keepdim(2)? / 3.0)?;
// The average value should be `b`.
assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
-let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(&[2])?.sqrt()? / 3.0)?;
+let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(2)?.sqrt()? / 3.0)?;
// The standard deviation should be sqrt(`w`).
assert_eq!(
std.to_vec3::<f32>()?,