Avoid re-encoding the input in the T5 example. (#875)

This commit is contained in:
Laurent Mazare
2023-09-17 11:25:54 +02:00
committed by GitHub
parent eeb54716dd
commit 7f65af1f0d
2 changed files with 17 additions and 4 deletions

View File

@ -171,6 +171,7 @@ fn main() -> Result<()> {
Some(args.temperature) Some(args.temperature)
}; };
let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p); let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
let encoder_output = model.encode(&input_token_ids)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
for index in 0.. { for index in 0.. {
@ -184,7 +185,7 @@ fn main() -> Result<()> {
Tensor::new(&[last_token], device)?.unsqueeze(0)? Tensor::new(&[last_token], device)?.unsqueeze(0)?
}; };
let logits = model let logits = model
.forward(&input_token_ids, &decoder_token_ids)? .decode(&decoder_token_ids, &encoder_output)?
.squeeze(0)?; .squeeze(0)?;
let logits = if args.repeat_penalty == 1. { let logits = if args.repeat_penalty == 1. {
logits logits

View File

@ -636,11 +636,18 @@ impl T5ForConditionalGeneration {
}) })
} }
pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> { pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let encoder_output = self.encoder.forward(input_ids, None)?; self.encoder.forward(input_ids, None)
}
pub fn decode(
&mut self,
decoder_input_ids: &Tensor,
encoder_output: &Tensor,
) -> Result<Tensor> {
let decoder_output = self let decoder_output = self
.decoder .decoder
.forward(decoder_input_ids, Some(&encoder_output))?; .forward(decoder_input_ids, Some(encoder_output))?;
let sequence_output = decoder_output let sequence_output = decoder_output
.narrow(1, decoder_output.dim(1)? - 1, 1)? .narrow(1, decoder_output.dim(1)? - 1, 1)?
.squeeze(1)?; .squeeze(1)?;
@ -651,6 +658,11 @@ impl T5ForConditionalGeneration {
Ok(output) Ok(output)
} }
pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
let encoder_output = self.encode(input_ids)?;
self.decode(decoder_input_ids, &encoder_output)
}
pub fn device(&self) -> &Device { pub fn device(&self) -> &Device {
&self.device &self.device
} }