mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Fix bug introduced in madlad PR (#1298)
This commit is contained in:
@ -670,7 +670,7 @@ pub struct T5EncoderModel {
|
||||
|
||||
impl T5EncoderModel {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let shared_vb = if vb.contains_tensor("shared") {
|
||||
let shared_vb = if vb.contains_tensor("shared.weight") {
|
||||
vb.pp("shared")
|
||||
} else {
|
||||
vb.pp("decoder").pp("embed_tokens")
|
||||
@ -716,7 +716,7 @@ impl T5ForConditionalGeneration {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
assert!(cfg.is_encoder_decoder);
|
||||
let d_model = cfg.d_model;
|
||||
let shared_vb = if vb.contains_tensor("shared") {
|
||||
let shared_vb = if vb.contains_tensor("shared.weight") {
|
||||
vb.pp("shared")
|
||||
} else {
|
||||
vb.pp("decoder").pp("embed_tokens")
|
||||
|
Reference in New Issue
Block a user