diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs index 03f5ef0f..8d03ec44 100644 --- a/candle-transformers/src/models/quantized_t5.rs +++ b/candle-transformers/src/models/quantized_t5.rs @@ -644,7 +644,7 @@ pub struct T5EncoderModel { impl T5EncoderModel { pub fn load(vb: VarBuilder, cfg: &Config) -> Result { - let shared_vb = if vb.contains_key("shared") { + let shared_vb = if vb.contains_key("shared.weight") { vb.pp("shared") } else { vb.pp("decoder").pp("embed_tokens") @@ -690,7 +690,7 @@ impl T5ForConditionalGeneration { pub fn load(vb: VarBuilder, cfg: &Config) -> Result { assert!(cfg.is_encoder_decoder); let d_model = cfg.d_model; - let shared_vb = if vb.contains_key("shared") { + let shared_vb = if vb.contains_key("shared.weight") { vb.pp("shared") } else { vb.pp("decoder").pp("embed_tokens") diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 3069be1c..479a55d9 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -670,7 +670,7 @@ pub struct T5EncoderModel { impl T5EncoderModel { pub fn load(vb: VarBuilder, cfg: &Config) -> Result { - let shared_vb = if vb.contains_tensor("shared") { + let shared_vb = if vb.contains_tensor("shared.weight") { vb.pp("shared") } else { vb.pp("decoder").pp("embed_tokens") @@ -716,7 +716,7 @@ impl T5ForConditionalGeneration { pub fn load(vb: VarBuilder, cfg: &Config) -> Result { assert!(cfg.is_encoder_decoder); let d_model = cfg.d_model; - let shared_vb = if vb.contains_tensor("shared") { + let shared_vb = if vb.contains_tensor("shared.weight") { vb.pp("shared") } else { vb.pp("decoder").pp("embed_tokens")