mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add some optional repeat penalty. (#623)
* Add some optional repeat penalty. * Add the missing files.
This commit is contained in:
@ -83,6 +83,14 @@ struct Args {
|
||||
/// (same structure as huggingface online)
|
||||
#[arg(long)]
|
||||
local_weights: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -200,6 +208,16 @@ fn main() -> Result<()> {
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let logits = llama.forward(&input, index_pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
index_pos += ctxt.len();
|
||||
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
|
@ -533,22 +533,6 @@ fn print_token(next_token: u32, tokenizer: &Tokenizer) {
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
|
||||
let mut logits = logits.to_vec1::<f32>()?;
|
||||
let context: std::collections::HashSet<_> = context.iter().collect();
|
||||
for (token_id, logit) in logits.iter_mut().enumerate() {
|
||||
if context.contains(&(token_id as u32)) {
|
||||
if *logit >= 0. {
|
||||
*logit /= penalty
|
||||
} else {
|
||||
*logit *= penalty
|
||||
}
|
||||
}
|
||||
}
|
||||
let logits_len = logits.len();
|
||||
Tensor::from_vec(logits, logits_len, &Device::Cpu)
|
||||
}
|
||||
|
||||
fn format_size(size_in_bytes: usize) -> String {
|
||||
if size_in_bytes < 1_000 {
|
||||
format!("{}B", size_in_bytes)
|
||||
@ -670,7 +654,11 @@ fn main() -> anyhow::Result<()> {
|
||||
logits
|
||||
} else {
|
||||
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||
apply_repeat_penalty(&logits, args.repeat_penalty, &all_tokens[start_at..])?
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&all_tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
|
Reference in New Issue
Block a user