Add some optional repeat penalty. (#535)

Laurent Mazare
2023-08-21 09:59:13 +01:00
committed by GitHub
parent d70cffdab6
commit 4300864ce9

@@ -325,6 +325,14 @@ struct Args {
     /// Display the token for the specified prompt.
     #[arg(long)]
     verbose_prompt: bool,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.0)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
 }
 
 impl Args {
@@ -378,6 +386,22 @@ fn print_token(next_token: u32, tokenizer: &Tokenizer) {
     }
 }
 
+fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
+    let mut logits = logits.to_vec1::<f32>()?;
+    let context: std::collections::HashSet<_> = context.iter().collect();
+    for (token_id, logit) in logits.iter_mut().enumerate() {
+        if context.contains(&(token_id as u32)) {
+            if *logit >= 0. {
+                *logit /= penalty
+            } else {
+                *logit *= penalty
+            }
+        }
+    }
+    let logits_len = logits.len();
+    Tensor::from_vec(logits, logits_len, &Device::Cpu)
+}
+
 fn main() -> anyhow::Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -439,6 +463,7 @@ fn main() -> anyhow::Result<()> {
     }
     let prompt_tokens = tokens.get_ids().to_vec();
+    let mut all_tokens = vec![];
     let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
     print!("{prompt}");
@@ -451,6 +476,7 @@ fn main() -> anyhow::Result<()> {
         logits_processor.sample(&logits)?
     };
     let prompt_dt = start_prompt_processing.elapsed();
+    all_tokens.push(next_token);
     print_token(next_token, &tokenizer);
     let to_sample = args.sample_len.saturating_sub(1);
@@ -459,7 +485,14 @@ fn main() -> anyhow::Result<()> {
         let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
         let logits = model.forward(&input, prompt_tokens.len() + index)?;
         let logits = logits.squeeze(0)?;
+        let logits = if args.repeat_penalty == 1. {
+            logits
+        } else {
+            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
+            apply_repeat_penalty(&logits, args.repeat_penalty, &all_tokens[start_at..])?
+        };
         next_token = logits_processor.sample(&logits)?;
+        all_tokens.push(next_token);
         print_token(next_token, &tokenizer);
     }
     let dt = start_post_prompt.elapsed();
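
The rule implemented by apply_repeat_penalty is: any token id that appeared among the last repeat_last_n generated tokens has its logit pushed down, by dividing positive logits by the penalty and multiplying negative ones by it, so a penalty of 1.0 is a no-op. Below is a minimal standalone sketch of that rule operating on a plain Vec<f32> instead of a candle Tensor; the helper name penalize, the token ids and the logit values are made up for illustration and are not part of this commit.

// Standalone sketch of the repeat-penalty rule (illustrative only).
use std::collections::HashSet;

fn penalize(logits: &mut [f32], penalty: f32, context: &[u32]) {
    let recent: HashSet<u32> = context.iter().copied().collect();
    for (token_id, logit) in logits.iter_mut().enumerate() {
        if recent.contains(&(token_id as u32)) {
            // Dividing a positive logit (or multiplying a negative one) by a
            // penalty > 1 moves it towards -inf, making the token less likely.
            if *logit >= 0. {
                *logit /= penalty
            } else {
                *logit *= penalty
            }
        }
    }
}

fn main() {
    let mut logits = vec![2.0, -1.0, 0.5, 3.0];
    // Pretend tokens 0 and 3 were among the last `repeat_last_n` sampled tokens.
    penalize(&mut logits, 1.1, &[0, 3]);
    // Tokens 0 and 3 are scaled down; tokens 1 and 2 are left untouched.
    println!("{logits:?}");
}

With the default repeat_penalty of 1.0 the sampling path is unchanged; only when a larger value is passed does the example penalize the most recent repeat_last_n (default 64) generated tokens. With clap's default argument naming these options surface on the command line as --repeat-penalty and --repeat-last-n.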