mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Simplify the parameters used by sum and sum_keepdim. (#165)
This commit is contained in:
@ -98,7 +98,7 @@ impl T5LayerNorm {
|
||||
let dtype = xs.dtype();
|
||||
let xs_f32 = xs.to_dtype(DType::F32)?;
|
||||
let xs2_f32 = (&xs_f32 * &xs_f32)?;
|
||||
let sum_xs2_f32 = xs2_f32.sum_keepdim(&[xs.rank() - 1])?;
|
||||
let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
|
||||
let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
|
||||
let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
|
||||
let xs = xs.to_dtype(dtype)?;
|
||||
|
Reference in New Issue
Block a user