diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 4544d828..8a7a8955 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -183,7 +183,7 @@ impl Module for T5LayerNorm { let xs_f32 = xs.to_dtype(DType::F32)?; // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?; - let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?; + let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?; let xs = xs.to_dtype(dtype)?; let xs = xs.broadcast_mul(&self.weight)?; Ok(xs)