Add some optional repeat penalty. (#535)

This commit is contained in:
Laurent Mazare
2023-08-21 09:59:13 +01:00
committed by GitHub
parent d70cffdab6
commit 4300864ce9

View File

@ -325,6 +325,14 @@ struct Args {
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
impl Args {
@ -378,6 +386,22 @@ fn print_token(next_token: u32, tokenizer: &Tokenizer) {
}
}
fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
let mut logits = logits.to_vec1::<f32>()?;
let context: std::collections::HashSet<_> = context.iter().collect();
for (token_id, logit) in logits.iter_mut().enumerate() {
if context.contains(&(token_id as u32)) {
if *logit >= 0. {
*logit /= penalty
} else {
*logit *= penalty
}
}
}
let logits_len = logits.len();
Tensor::from_vec(logits, logits_len, &Device::Cpu)
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
@ -439,6 +463,7 @@ fn main() -> anyhow::Result<()> {
}
let prompt_tokens = tokens.get_ids().to_vec();
let mut all_tokens = vec![];
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
print!("{prompt}");
@ -451,6 +476,7 @@ fn main() -> anyhow::Result<()> {
logits_processor.sample(&logits)?
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
print_token(next_token, &tokenizer);
let to_sample = args.sample_len.saturating_sub(1);
@ -459,7 +485,14 @@ fn main() -> anyhow::Result<()> {
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
apply_repeat_penalty(&logits, args.repeat_penalty, &all_tokens[start_at..])?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
print_token(next_token, &tokenizer);
}
let dt = start_post_prompt.elapsed();