Fix bug introduced in madlad PR (#1298)

This commit is contained in:
Juarez Bochi
2023-11-08 11:55:46 -05:00
committed by GitHub
parent 2feb0b054f
commit f772213e84
2 changed files with 4 additions and 4 deletions

View File

@ -670,7 +670,7 @@ pub struct T5EncoderModel {
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared_vb = if vb.contains_tensor("shared") {
let shared_vb = if vb.contains_tensor("shared.weight") {
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")
@ -716,7 +716,7 @@ impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model;
let shared_vb = if vb.contains_tensor("shared") {
let shared_vb = if vb.contains_tensor("shared.weight") {
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")