From 7f65af1f0dc518ef17623d718f3901f58c3aab06 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 17 Sep 2023 11:25:54 +0200 Subject: [PATCH] Avoid re-encoding the input in the T5 example. (#875) --- candle-examples/examples/t5/main.rs | 3 ++- candle-transformers/src/models/t5.rs | 18 +++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 72be23bc..36cbee7c 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -171,6 +171,7 @@ fn main() -> Result<()> { Some(args.temperature) }; let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p); + let encoder_output = model.encode(&input_token_ids)?; let start = std::time::Instant::now(); for index in 0.. { @@ -184,7 +185,7 @@ fn main() -> Result<()> { Tensor::new(&[last_token], device)?.unsqueeze(0)? }; let logits = model - .forward(&input_token_ids, &decoder_token_ids)? + .decode(&decoder_token_ids, &encoder_output)? .squeeze(0)?; let logits = if args.repeat_penalty == 1. { logits diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 8b621f64..2ffc2ee1 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -636,11 +636,18 @@ impl T5ForConditionalGeneration { }) } - pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> { - let encoder_output = self.encoder.forward(input_ids, None)?; + pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> { + self.encoder.forward(input_ids, None) + } + + pub fn decode( + &mut self, + decoder_input_ids: &Tensor, + encoder_output: &Tensor, + ) -> Result<Tensor> { let decoder_output = self .decoder - .forward(decoder_input_ids, Some(&encoder_output))?; + .forward(decoder_input_ids, Some(encoder_output))?; let sequence_output = decoder_output .narrow(1, decoder_output.dim(1)? - 1, 1)? 
.squeeze(1)?; @@ -651,6 +658,11 @@ impl T5ForConditionalGeneration { Ok(output) } + pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> { + let encoder_output = self.encode(input_ids)?; + self.decode(decoder_input_ids, &encoder_output) + } + pub fn device(&self) -> &Device { &self.device }