mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Fix bug introduced in madlad PR (#1298)
This commit is contained in:
@ -644,7 +644,7 @@ pub struct T5EncoderModel {
|
|||||||
|
|
||||||
impl T5EncoderModel {
|
impl T5EncoderModel {
|
||||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let shared_vb = if vb.contains_key("shared") {
|
let shared_vb = if vb.contains_key("shared.weight") {
|
||||||
vb.pp("shared")
|
vb.pp("shared")
|
||||||
} else {
|
} else {
|
||||||
vb.pp("decoder").pp("embed_tokens")
|
vb.pp("decoder").pp("embed_tokens")
|
||||||
@ -690,7 +690,7 @@ impl T5ForConditionalGeneration {
|
|||||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
assert!(cfg.is_encoder_decoder);
|
assert!(cfg.is_encoder_decoder);
|
||||||
let d_model = cfg.d_model;
|
let d_model = cfg.d_model;
|
||||||
let shared_vb = if vb.contains_key("shared") {
|
let shared_vb = if vb.contains_key("shared.weight") {
|
||||||
vb.pp("shared")
|
vb.pp("shared")
|
||||||
} else {
|
} else {
|
||||||
vb.pp("decoder").pp("embed_tokens")
|
vb.pp("decoder").pp("embed_tokens")
|
||||||
|
@ -670,7 +670,7 @@ pub struct T5EncoderModel {
|
|||||||
|
|
||||||
impl T5EncoderModel {
|
impl T5EncoderModel {
|
||||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let shared_vb = if vb.contains_tensor("shared") {
|
let shared_vb = if vb.contains_tensor("shared.weight") {
|
||||||
vb.pp("shared")
|
vb.pp("shared")
|
||||||
} else {
|
} else {
|
||||||
vb.pp("decoder").pp("embed_tokens")
|
vb.pp("decoder").pp("embed_tokens")
|
||||||
@ -716,7 +716,7 @@ impl T5ForConditionalGeneration {
|
|||||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
assert!(cfg.is_encoder_decoder);
|
assert!(cfg.is_encoder_decoder);
|
||||||
let d_model = cfg.d_model;
|
let d_model = cfg.d_model;
|
||||||
let shared_vb = if vb.contains_tensor("shared") {
|
let shared_vb = if vb.contains_tensor("shared.weight") {
|
||||||
vb.pp("shared")
|
vb.pp("shared")
|
||||||
} else {
|
} else {
|
||||||
vb.pp("decoder").pp("embed_tokens")
|
vb.pp("decoder").pp("embed_tokens")
|
||||||
|
Reference in New Issue
Block a user