Add support for MADLAD400 (#1285)

* Add support for madlad

* Add support for quantized MADLAD
This commit is contained in:
Juarez Bochi
2023-11-06 23:35:37 -05:00
committed by GitHub
parent a773a4b22b
commit 508f811b93
5 changed files with 44 additions and 6 deletions

View File

@ -63,6 +63,7 @@ pub struct Config {
pub use_cache: bool,
pub pad_token_id: usize,
pub eos_token_id: usize,
pub decoder_start_token_id: Option<usize>,
}
impl Default for Config {
@ -87,6 +88,7 @@ impl Default for Config {
use_cache: true,
pad_token_id: 0,
eos_token_id: 1,
decoder_start_token_id: Some(0),
}
}
}
@ -110,6 +112,7 @@ impl Config {
num_heads: 12,
num_layers: 12,
pad_token_id: 0,
decoder_start_token_id: Some(0),
relative_attention_max_distance: 128,
relative_attention_num_buckets: 32,
use_cache: true,
@ -667,7 +670,12 @@ pub struct T5EncoderModel {
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared_vb = if vb.contains_tensor("shared") {
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")
};
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared);
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
Ok(Self {
@ -708,7 +716,12 @@ impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model;
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared_vb = if vb.contains_tensor("shared") {
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")
};
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared);
let mut encoder_cfg = cfg.clone();