mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Avoid re-encoding the input in the T5 example. (#875)
This commit is contained in:
@ -171,6 +171,7 @@ fn main() -> Result<()> {
|
||||
Some(args.temperature)
|
||||
};
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
|
||||
let encoder_output = model.encode(&input_token_ids)?;
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
for index in 0.. {
|
||||
@ -184,7 +185,7 @@ fn main() -> Result<()> {
|
||||
Tensor::new(&[last_token], device)?.unsqueeze(0)?
|
||||
};
|
||||
let logits = model
|
||||
.forward(&input_token_ids, &decoder_token_ids)?
|
||||
.decode(&decoder_token_ids, &encoder_output)?
|
||||
.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
|
Reference in New Issue
Block a user