diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs index 0df3a001..05507f08 100644 --- a/candle-examples/examples/falcon/main.rs +++ b/candle-examples/examples/falcon/main.rs @@ -22,6 +22,8 @@ struct TextGeneration { device: Device, tokenizer: Tokenizer, logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, } impl TextGeneration { @@ -31,6 +33,8 @@ impl TextGeneration { seed: u64, temp: Option<f64>, device: &Device, + repeat_penalty: f32, + repeat_last_n: usize, ) -> Self { let logits_processor = LogitsProcessor::new(seed, temp); Self { @@ -38,6 +42,8 @@ impl TextGeneration { tokenizer, logits_processor, device: device.clone(), + repeat_penalty, + repeat_last_n, } } @@ -63,6 +69,16 @@ impl TextGeneration { let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; let logits = self.model.forward(&input)?; let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; let next_token = self.logits_processor.sample(&logits)?; tokens.push(next_token); @@ -116,6 +132,14 @@ struct Args { #[arg(long, default_value = "refs/pr/43")] revision: String, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.0)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. 
+ #[arg(long, default_value_t = 64)] + repeat_last_n: usize, } fn main() -> Result<()> { @@ -162,7 +186,15 @@ fn main() -> Result<()> { let model = Falcon::load(vb, config)?; println!("loaded the model in {:?}", start.elapsed()); - let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device); + let mut pipeline = TextGeneration::new( + model, + tokenizer, + args.seed, + args.temperature, + &device, + args.repeat_penalty, + args.repeat_last_n, + ); pipeline.run(&args.prompt, args.sample_len)?; Ok(()) }