mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Use the same default as pytorch for sum. (#164)
This commit is contained in:
@ -30,10 +30,10 @@ fn layer_norm() -> Result<()> {
|
||||
[4.1742344, 0.5, -3.1742344]
|
||||
]]
|
||||
);
|
||||
let mean = (res.sum(&[2])? / 3.0)?;
|
||||
let mean = (res.sum_keepdim(&[2])? / 3.0)?;
|
||||
// The average value should be `b`.
|
||||
assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
|
||||
let std = (res.broadcast_sub(&mean)?.sqr()?.sum(&[2])?.sqrt()? / 3.0)?;
|
||||
let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(&[2])?.sqrt()? / 3.0)?;
|
||||
// The standard deviation should be sqrt(`w`).
|
||||
assert_eq!(
|
||||
std.to_vec3::<f32>()?,
|
||||
|
Reference in New Issue
Block a user