diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index f4b5b4b0..4544d828 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -709,8 +709,10 @@ impl T5EncoderModel { pub fn load(vb: VarBuilder, cfg: &Config) -> Result { let shared_vb = if vb.contains_tensor("shared.weight") { vb.pp("shared") - } else { + } else if vb.contains_tensor("decoder.embed_tokens") { vb.pp("decoder").pp("embed_tokens") + } else { + vb.pp("encoder").pp("embed_tokens") }; let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?; let shared = Arc::new(shared);