diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 2dc46217..e4a7a360 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -325,6 +325,14 @@ struct Args {
     /// Display the token for the specified prompt.
     #[arg(long)]
     verbose_prompt: bool,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.0)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
 }
 
 impl Args {
@@ -378,6 +386,22 @@ fn print_token(next_token: u32, tokenizer: &Tokenizer) {
     }
 }
 
+fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
+    let mut logits = logits.to_vec1::<f32>()?;
+    let context: std::collections::HashSet<_> = context.iter().collect();
+    for (token_id, logit) in logits.iter_mut().enumerate() {
+        if context.contains(&(token_id as u32)) {
+            if *logit >= 0. {
+                *logit /= penalty
+            } else {
+                *logit *= penalty
+            }
+        }
+    }
+    let logits_len = logits.len();
+    Tensor::from_vec(logits, logits_len, &Device::Cpu)
+}
+
 fn main() -> anyhow::Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -439,6 +463,7 @@ fn main() -> anyhow::Result<()> {
     }
 
     let prompt_tokens = tokens.get_ids().to_vec();
+    let mut all_tokens = vec![];
     let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
 
     print!("{prompt}");
@@ -451,6 +476,7 @@ fn main() -> anyhow::Result<()> {
         logits_processor.sample(&logits)?
     };
     let prompt_dt = start_prompt_processing.elapsed();
+    all_tokens.push(next_token);
    print_token(next_token, &tokenizer);
 
     let to_sample = args.sample_len.saturating_sub(1);
@@ -459,7 +485,14 @@ fn main() -> anyhow::Result<()> {
         let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
         let logits = model.forward(&input, prompt_tokens.len() + index)?;
         let logits = logits.squeeze(0)?;
+        let logits = if args.repeat_penalty == 1. {
+            logits
+        } else {
+            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
+            apply_repeat_penalty(&logits, args.repeat_penalty, &all_tokens[start_at..])?
+        };
         next_token = logits_processor.sample(&logits)?;
+        all_tokens.push(next_token);
         print_token(next_token, &tokenizer);
     }
     let dt = start_post_prompt.elapsed();
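
For reference, the penalty rule added in `apply_repeat_penalty` can be illustrated without candle's `Tensor` type: any token id that appears in the last `repeat_last_n` sampled tokens has its logit divided by the penalty when it is non-negative and multiplied by it when negative, pushing its probability down in both cases. The sketch below uses plain `f32` slices; the helper name `penalize` and the sample values are illustrative and not part of the patch.

```rust
use std::collections::HashSet;

// Same update rule as apply_repeat_penalty, but on a plain f32 slice.
fn penalize(logits: &mut [f32], penalty: f32, context: &[u32]) {
    let context: HashSet<u32> = context.iter().copied().collect();
    for (token_id, logit) in logits.iter_mut().enumerate() {
        if context.contains(&(token_id as u32)) {
            if *logit >= 0. {
                *logit /= penalty; // positive logits shrink toward zero
            } else {
                *logit *= penalty; // negative logits become more negative
            }
        }
    }
}

fn main() {
    // Token 1 was seen recently, so its logit drops from 2.0 to 2.0 / 1.1 ≈ 1.82;
    // tokens 0 and 2 are left untouched.
    let mut logits = vec![0.5_f32, 2.0, -1.0];
    penalize(&mut logits, 1.1, &[1]);
    println!("{logits:?}");
}
```

Since the new fields use clap's `#[arg(long)]`, they should surface as `--repeat-penalty` and `--repeat-last-n` on the example's command line, with the defaults (1.0 and 64) leaving behaviour unchanged unless a penalty above 1 is requested.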