Tweaks for the T5 example. (#874)

This commit is contained in:
Laurent Mazare
2023-09-17 11:05:15 +02:00
committed by GitHub
parent 1a276b5da7
commit eeb54716dd
2 changed files with 38 additions and 4 deletions

View File

@ -82,7 +82,7 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).

View File

@ -51,6 +51,22 @@ struct Args {
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
struct T5ModelBuilder {
@ -149,7 +165,12 @@ fn main() -> Result<()> {
} else {
let mut model = builder.build_conditional_generation()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let mut logits_processor = LogitsProcessor::new(299792458, None, None);
let temperature = if args.temperature <= 0. {
None
} else {
Some(args.temperature)
};
let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
let start = std::time::Instant::now();
for index in 0.. {
@ -162,8 +183,21 @@ fn main() -> Result<()> {
let last_token = *output_token_ids.last().unwrap();
Tensor::new(&[last_token], device)?.unsqueeze(0)?
};
let logits = model.forward(&input_token_ids, &decoder_token_ids)?;
let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
let logits = model
.forward(&input_token_ids, &decoder_token_ids)?
.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token_id = logits_processor.sample(&logits)?;
if next_token_id as usize == builder.config.eos_token_id {
break;
}