diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 418218b6..e0ade322 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -103,6 +103,14 @@ pub struct Args { /// Tokenizer config file. #[arg(long)] tokenizer: Option, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, } impl Args { @@ -268,6 +276,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let logits = model.forward(&input, index_pos)?; let logits = logits.i((0, logits.dim(1)? - 1))?; + let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() { + logits + } else { + let start_at = tokens.len().saturating_sub(common_args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + common_args.repeat_penalty, + &tokens[start_at..], + )? + }; index_pos += ctxt.len(); let next_token = logits_processor.sample(&logits)?; diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs index a71508ee..d014e38a 100644 --- a/candle-wasm-examples/llama2-c/src/bin/m.rs +++ b/candle-wasm-examples/llama2-c/src/bin/m.rs @@ -25,7 +25,7 @@ impl Model { candle_transformers::utils::apply_repeat_penalty( &logits, self.repeat_penalty, - &tokens[start_at..], + &self.tokens[start_at..], )? };