mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add some optional repeat penalty. (#535)
This commit is contained in:
@ -325,6 +325,14 @@ struct Args {
|
|||||||
/// Display the token for the specified prompt.
|
/// Display the token for the specified prompt.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
verbose_prompt: bool,
|
verbose_prompt: bool,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.0)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
@ -378,6 +386,22 @@ fn print_token(next_token: u32, tokenizer: &Tokenizer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
|
||||||
|
let mut logits = logits.to_vec1::<f32>()?;
|
||||||
|
let context: std::collections::HashSet<_> = context.iter().collect();
|
||||||
|
for (token_id, logit) in logits.iter_mut().enumerate() {
|
||||||
|
if context.contains(&(token_id as u32)) {
|
||||||
|
if *logit >= 0. {
|
||||||
|
*logit /= penalty
|
||||||
|
} else {
|
||||||
|
*logit *= penalty
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let logits_len = logits.len();
|
||||||
|
Tensor::from_vec(logits, logits_len, &Device::Cpu)
|
||||||
|
}
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
@ -439,6 +463,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let prompt_tokens = tokens.get_ids().to_vec();
|
let prompt_tokens = tokens.get_ids().to_vec();
|
||||||
|
let mut all_tokens = vec![];
|
||||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
||||||
|
|
||||||
print!("{prompt}");
|
print!("{prompt}");
|
||||||
@ -451,6 +476,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
logits_processor.sample(&logits)?
|
logits_processor.sample(&logits)?
|
||||||
};
|
};
|
||||||
let prompt_dt = start_prompt_processing.elapsed();
|
let prompt_dt = start_prompt_processing.elapsed();
|
||||||
|
all_tokens.push(next_token);
|
||||||
print_token(next_token, &tokenizer);
|
print_token(next_token, &tokenizer);
|
||||||
|
|
||||||
let to_sample = args.sample_len.saturating_sub(1);
|
let to_sample = args.sample_len.saturating_sub(1);
|
||||||
@ -459,7 +485,14 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
|
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
|
||||||
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
|
let logits = if args.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||||
|
apply_repeat_penalty(&logits, args.repeat_penalty, &all_tokens[start_at..])?
|
||||||
|
};
|
||||||
next_token = logits_processor.sample(&logits)?;
|
next_token = logits_processor.sample(&logits)?;
|
||||||
|
all_tokens.push(next_token);
|
||||||
print_token(next_token, &tokenizer);
|
print_token(next_token, &tokenizer);
|
||||||
}
|
}
|
||||||
let dt = start_post_prompt.elapsed();
|
let dt = start_post_prompt.elapsed();
|
||||||
|
Reference in New Issue
Block a user