Add support for MADLAD400 (#1285)

* Add support for MADLAD

* Add support for quantized MADLAD
This commit is contained in:
Juarez Bochi
2023-11-06 23:35:37 -05:00
committed by GitHub
parent a773a4b22b
commit 508f811b93
5 changed files with 44 additions and 6 deletions

View File

@@ -173,7 +173,11 @@ fn main() -> Result<()> {
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mut model = builder.build_model()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let mut output_token_ids = [builder
.config
.decoder_start_token_id
.unwrap_or(builder.config.pad_token_id) as u32]
.to_vec();
let temperature = if args.temperature <= 0. {
None
} else {

View File

@@ -172,7 +172,12 @@ fn main() -> Result<()> {
println!("Took {:?}", start.elapsed());
} else {
let mut model = builder.build_conditional_generation()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let mut output_token_ids = [builder
.config
.decoder_start_token_id
.unwrap_or(builder.config.pad_token_id)
as u32]
.to_vec();
if let Some(decoder_prompt) = &args.decoder_prompt {
print!("{decoder_prompt}");
output_token_ids.extend(