diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 214ace38..4544d828 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -183,7 +183,7 @@ impl Module for T5LayerNorm {
         let xs_f32 = xs.to_dtype(DType::F32)?;
         // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
-        let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
+        let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
         let xs = xs.to_dtype(dtype)?;
         let xs = xs.broadcast_mul(&self.weight)?;
         Ok(xs)
@@ -709,8 +709,10 @@ impl T5EncoderModel {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let shared_vb = if vb.contains_tensor("shared.weight") {
             vb.pp("shared")
-        } else {
+        } else if vb.contains_tensor("decoder.embed_tokens") {
             vb.pp("decoder").pp("embed_tokens")
+        } else {
+            vb.pp("encoder").pp("embed_tokens")
         };
         let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
         let shared = Arc::new(shared);