Add a repeat penalty to the llama2-c command line example. (#713)

* Add a repeat penalty to the llama2-c command line example.

* Another fix attempt.
This commit is contained in:
Laurent Mazare
2023-09-01 21:38:58 +02:00
committed by GitHub
parent 4d56cef583
commit 2c1df6bba1
2 changed files with 19 additions and 1 deletions

View File

@ -103,6 +103,14 @@ pub struct Args {
/// Tokenizer config file. /// Tokenizer config file.
#[arg(long)] #[arg(long)]
tokenizer: Option<String>, tokenizer: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
} }
impl Args { impl Args {
@ -268,6 +276,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos)?; let logits = model.forward(&input, index_pos)?;
let logits = logits.i((0, logits.dim(1)? - 1))?; let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
logits
} else {
let start_at = tokens.len().saturating_sub(common_args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
common_args.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len(); index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?; let next_token = logits_processor.sample(&logits)?;

View File

@ -25,7 +25,7 @@ impl Model {
candle_transformers::utils::apply_repeat_penalty( candle_transformers::utils::apply_repeat_penalty(
&logits, &logits,
self.repeat_penalty, self.repeat_penalty,
&tokens[start_at..], &self.tokens[start_at..],
)? )?
}; };