mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add support for MADLAD400 (#1285)
* Add support for madlad * Add support for quantized MADLAD
This commit is contained in:
@ -173,7 +173,11 @@ fn main() -> Result<()> {
|
||||
.to_vec();
|
||||
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||
let mut model = builder.build_model()?;
|
||||
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
||||
let mut output_token_ids = [builder
|
||||
.config
|
||||
.decoder_start_token_id
|
||||
.unwrap_or(builder.config.pad_token_id) as u32]
|
||||
.to_vec();
|
||||
let temperature = if args.temperature <= 0. {
|
||||
None
|
||||
} else {
|
||||
|
@ -172,7 +172,12 @@ fn main() -> Result<()> {
|
||||
println!("Took {:?}", start.elapsed());
|
||||
} else {
|
||||
let mut model = builder.build_conditional_generation()?;
|
||||
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
||||
let mut output_token_ids = [builder
|
||||
.config
|
||||
.decoder_start_token_id
|
||||
.unwrap_or(builder.config.pad_token_id)
|
||||
as u32]
|
||||
.to_vec();
|
||||
if let Some(decoder_prompt) = &args.decoder_prompt {
|
||||
print!("{decoder_prompt}");
|
||||
output_token_ids.extend(
|
||||
|
Reference in New Issue
Block a user