mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a repeat penality to the llama2-c command line example. (#713)
* Add a repeat penality to the llama2-c command line example. * Another fix attempt.
This commit is contained in:
@ -103,6 +103,14 @@ pub struct Args {
|
|||||||
/// Tokenizer config file.
|
/// Tokenizer config file.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tokenizer: Option<String>,
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
@ -268,6 +276,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||||
let logits = model.forward(&input, index_pos)?;
|
let logits = model.forward(&input, index_pos)?;
|
||||||
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
||||||
|
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(common_args.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
common_args.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
index_pos += ctxt.len();
|
index_pos += ctxt.len();
|
||||||
|
|
||||||
let next_token = logits_processor.sample(&logits)?;
|
let next_token = logits_processor.sample(&logits)?;
|
||||||
|
@ -25,7 +25,7 @@ impl Model {
|
|||||||
candle_transformers::utils::apply_repeat_penalty(
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
&logits,
|
&logits,
|
||||||
self.repeat_penalty,
|
self.repeat_penalty,
|
||||||
&tokens[start_at..],
|
&self.tokens[start_at..],
|
||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user