mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Avoid re-encoding the input in the T5 example. (#875)
This commit is contained in:
@ -636,11 +636,18 @@ impl T5ForConditionalGeneration {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
|
||||
let encoder_output = self.encoder.forward(input_ids, None)?;
|
||||
pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
self.encoder.forward(input_ids, None)
|
||||
}
|
||||
|
||||
pub fn decode(
|
||||
&mut self,
|
||||
decoder_input_ids: &Tensor,
|
||||
encoder_output: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let decoder_output = self
|
||||
.decoder
|
||||
.forward(decoder_input_ids, Some(&encoder_output))?;
|
||||
.forward(decoder_input_ids, Some(encoder_output))?;
|
||||
let sequence_output = decoder_output
|
||||
.narrow(1, decoder_output.dim(1)? - 1, 1)?
|
||||
.squeeze(1)?;
|
||||
@ -651,6 +658,11 @@ impl T5ForConditionalGeneration {
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
|
||||
let encoder_output = self.encode(input_ids)?;
|
||||
self.decode(decoder_input_ids, &encoder_output)
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &Device {
|
||||
&self.device
|
||||
}
|
||||
|
Reference in New Issue
Block a user