mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add support for MADLAD400 (#1285)
* Add support for madlad * Add support for quantized MADLAD
This commit is contained in:
@ -173,7 +173,11 @@ fn main() -> Result<()> {
|
|||||||
.to_vec();
|
.to_vec();
|
||||||
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||||
let mut model = builder.build_model()?;
|
let mut model = builder.build_model()?;
|
||||||
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
let mut output_token_ids = [builder
|
||||||
|
.config
|
||||||
|
.decoder_start_token_id
|
||||||
|
.unwrap_or(builder.config.pad_token_id) as u32]
|
||||||
|
.to_vec();
|
||||||
let temperature = if args.temperature <= 0. {
|
let temperature = if args.temperature <= 0. {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
|
@ -172,7 +172,12 @@ fn main() -> Result<()> {
|
|||||||
println!("Took {:?}", start.elapsed());
|
println!("Took {:?}", start.elapsed());
|
||||||
} else {
|
} else {
|
||||||
let mut model = builder.build_conditional_generation()?;
|
let mut model = builder.build_conditional_generation()?;
|
||||||
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
let mut output_token_ids = [builder
|
||||||
|
.config
|
||||||
|
.decoder_start_token_id
|
||||||
|
.unwrap_or(builder.config.pad_token_id)
|
||||||
|
as u32]
|
||||||
|
.to_vec();
|
||||||
if let Some(decoder_prompt) = &args.decoder_prompt {
|
if let Some(decoder_prompt) = &args.decoder_prompt {
|
||||||
print!("{decoder_prompt}");
|
print!("{decoder_prompt}");
|
||||||
output_token_ids.extend(
|
output_token_ids.extend(
|
||||||
|
@ -65,6 +65,7 @@ pub struct Config {
|
|||||||
pub use_cache: bool,
|
pub use_cache: bool,
|
||||||
pub pad_token_id: usize,
|
pub pad_token_id: usize,
|
||||||
pub eos_token_id: usize,
|
pub eos_token_id: usize,
|
||||||
|
pub decoder_start_token_id: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Config {
|
impl Default for Config {
|
||||||
@ -89,6 +90,7 @@ impl Default for Config {
|
|||||||
use_cache: true,
|
use_cache: true,
|
||||||
pad_token_id: 0,
|
pad_token_id: 0,
|
||||||
eos_token_id: 1,
|
eos_token_id: 1,
|
||||||
|
decoder_start_token_id: Some(0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -642,7 +644,12 @@ pub struct T5EncoderModel {
|
|||||||
|
|
||||||
impl T5EncoderModel {
|
impl T5EncoderModel {
|
||||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
let shared_vb = if vb.contains_key("shared") {
|
||||||
|
vb.pp("shared")
|
||||||
|
} else {
|
||||||
|
vb.pp("decoder").pp("embed_tokens")
|
||||||
|
};
|
||||||
|
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
|
||||||
let shared = Arc::new(shared);
|
let shared = Arc::new(shared);
|
||||||
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
|
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -683,7 +690,12 @@ impl T5ForConditionalGeneration {
|
|||||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
assert!(cfg.is_encoder_decoder);
|
assert!(cfg.is_encoder_decoder);
|
||||||
let d_model = cfg.d_model;
|
let d_model = cfg.d_model;
|
||||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
let shared_vb = if vb.contains_key("shared") {
|
||||||
|
vb.pp("shared")
|
||||||
|
} else {
|
||||||
|
vb.pp("decoder").pp("embed_tokens")
|
||||||
|
};
|
||||||
|
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
|
||||||
let shared = Arc::new(shared);
|
let shared = Arc::new(shared);
|
||||||
|
|
||||||
let mut encoder_cfg = cfg.clone();
|
let mut encoder_cfg = cfg.clone();
|
||||||
|
@ -63,6 +63,7 @@ pub struct Config {
|
|||||||
pub use_cache: bool,
|
pub use_cache: bool,
|
||||||
pub pad_token_id: usize,
|
pub pad_token_id: usize,
|
||||||
pub eos_token_id: usize,
|
pub eos_token_id: usize,
|
||||||
|
pub decoder_start_token_id: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Config {
|
impl Default for Config {
|
||||||
@ -87,6 +88,7 @@ impl Default for Config {
|
|||||||
use_cache: true,
|
use_cache: true,
|
||||||
pad_token_id: 0,
|
pad_token_id: 0,
|
||||||
eos_token_id: 1,
|
eos_token_id: 1,
|
||||||
|
decoder_start_token_id: Some(0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -110,6 +112,7 @@ impl Config {
|
|||||||
num_heads: 12,
|
num_heads: 12,
|
||||||
num_layers: 12,
|
num_layers: 12,
|
||||||
pad_token_id: 0,
|
pad_token_id: 0,
|
||||||
|
decoder_start_token_id: Some(0),
|
||||||
relative_attention_max_distance: 128,
|
relative_attention_max_distance: 128,
|
||||||
relative_attention_num_buckets: 32,
|
relative_attention_num_buckets: 32,
|
||||||
use_cache: true,
|
use_cache: true,
|
||||||
@ -667,7 +670,12 @@ pub struct T5EncoderModel {
|
|||||||
|
|
||||||
impl T5EncoderModel {
|
impl T5EncoderModel {
|
||||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
let shared_vb = if vb.contains_tensor("shared") {
|
||||||
|
vb.pp("shared")
|
||||||
|
} else {
|
||||||
|
vb.pp("decoder").pp("embed_tokens")
|
||||||
|
};
|
||||||
|
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
|
||||||
let shared = Arc::new(shared);
|
let shared = Arc::new(shared);
|
||||||
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
|
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -708,7 +716,12 @@ impl T5ForConditionalGeneration {
|
|||||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
assert!(cfg.is_encoder_decoder);
|
assert!(cfg.is_encoder_decoder);
|
||||||
let d_model = cfg.d_model;
|
let d_model = cfg.d_model;
|
||||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
let shared_vb = if vb.contains_tensor("shared") {
|
||||||
|
vb.pp("shared")
|
||||||
|
} else {
|
||||||
|
vb.pp("decoder").pp("embed_tokens")
|
||||||
|
};
|
||||||
|
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
|
||||||
let shared = Arc::new(shared);
|
let shared = Arc::new(shared);
|
||||||
|
|
||||||
let mut encoder_cfg = cfg.clone();
|
let mut encoder_cfg = cfg.clone();
|
||||||
|
@ -90,4 +90,8 @@ impl VarBuilder {
|
|||||||
pub fn device(&self) -> &Device {
|
pub fn device(&self) -> &Device {
|
||||||
&self.device
|
&self.device
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn contains_key(&self, key: &str) -> bool {
|
||||||
|
self.data.contains_key(key)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user