mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Avoid re-encoding the input in the T5 example. (#875)
This commit is contained in:
@ -171,6 +171,7 @@ fn main() -> Result<()> {
|
|||||||
Some(args.temperature)
|
Some(args.temperature)
|
||||||
};
|
};
|
||||||
let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
|
let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
|
||||||
|
let encoder_output = model.encode(&input_token_ids)?;
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
for index in 0.. {
|
for index in 0.. {
|
||||||
@ -184,7 +185,7 @@ fn main() -> Result<()> {
|
|||||||
Tensor::new(&[last_token], device)?.unsqueeze(0)?
|
Tensor::new(&[last_token], device)?.unsqueeze(0)?
|
||||||
};
|
};
|
||||||
let logits = model
|
let logits = model
|
||||||
.forward(&input_token_ids, &decoder_token_ids)?
|
.decode(&decoder_token_ids, &encoder_output)?
|
||||||
.squeeze(0)?;
|
.squeeze(0)?;
|
||||||
let logits = if args.repeat_penalty == 1. {
|
let logits = if args.repeat_penalty == 1. {
|
||||||
logits
|
logits
|
||||||
|
@ -636,11 +636,18 @@ impl T5ForConditionalGeneration {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
|
pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
let encoder_output = self.encoder.forward(input_ids, None)?;
|
self.encoder.forward(input_ids, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decode(
|
||||||
|
&mut self,
|
||||||
|
decoder_input_ids: &Tensor,
|
||||||
|
encoder_output: &Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let decoder_output = self
|
let decoder_output = self
|
||||||
.decoder
|
.decoder
|
||||||
.forward(decoder_input_ids, Some(&encoder_output))?;
|
.forward(decoder_input_ids, Some(encoder_output))?;
|
||||||
let sequence_output = decoder_output
|
let sequence_output = decoder_output
|
||||||
.narrow(1, decoder_output.dim(1)? - 1, 1)?
|
.narrow(1, decoder_output.dim(1)? - 1, 1)?
|
||||||
.squeeze(1)?;
|
.squeeze(1)?;
|
||||||
@ -651,6 +658,11 @@ impl T5ForConditionalGeneration {
|
|||||||
Ok(output)
|
Ok(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
|
||||||
|
let encoder_output = self.encode(input_ids)?;
|
||||||
|
self.decode(decoder_input_ids, &encoder_output)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn device(&self) -> &Device {
|
pub fn device(&self) -> &Device {
|
||||||
&self.device
|
&self.device
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user