Add some optional repeat penalty. (#623)

* Add some optional repeat penalty.
* Add the missing files.

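The penalty implements the divide/multiply rule introduced by the CTRL paper: a token that already appeared in the recent context has its logit divided by the penalty when positive and multiplied by it when negative, so a repeated token always becomes less likely to be sampled. With the default penalty of 1.0 the sampling path is unchanged, and the helper is promoted into a shared candle_transformers::utils module so both touched examples use the same code. A minimal sketch of the rule on plain floats (the penalize function is illustrative only, not part of this commit):

    // Push a repeated token's logit toward lower probability regardless of sign.
    fn penalize(logit: f32, penalty: f32) -> f32 {
        if logit >= 0. { logit / penalty } else { logit * penalty }
    }

    fn main() {
        assert_eq!(penalize(2.0, 2.0), 1.0); // positive logit shrinks toward zero
        assert_eq!(penalize(-1.0, 2.0), -2.0); // negative logit moves further below zero
    }
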
@@ -83,6 +83,14 @@ struct Args {
     /// (same structure as huggingface online)
     #[arg(long)]
     local_weights: Option<String>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.0)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
 }
 
 fn main() -> Result<()> {

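With clap's derive defaults the new fields surface on the command line as --repeat-penalty and --repeat-last-n (kebab-case of the field names). A hypothetical invocation, assuming the usual cargo example layout of the repository:

    cargo run --example llama --release -- --repeat-penalty 1.1 --repeat-last-n 64
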
@@ -200,6 +208,16 @@ fn main() -> Result<()> {
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let logits = llama.forward(&input, index_pos)?;
         let logits = logits.squeeze(0)?;
+        let logits = if args.repeat_penalty == 1. {
+            logits
+        } else {
+            let start_at = tokens.len().saturating_sub(args.repeat_last_n);
+            candle_transformers::utils::apply_repeat_penalty(
+                &logits,
+                args.repeat_penalty,
+                &tokens[start_at..],
+            )?
+        };
         index_pos += ctxt.len();
 
         let next_token = logits_processor.sample(&logits)?;

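Two details worth noting: the == 1. early-out skips the helper entirely when no penalty was requested, avoiding its tensor-to-Vec-and-back round trip, and saturating_sub keeps the window well-defined early in generation, when fewer than repeat_last_n tokens have been produced so far.
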
@@ -533,22 +533,6 @@ fn print_token(next_token: u32, tokenizer: &Tokenizer) {
     }
 }
 
-fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
-    let mut logits = logits.to_vec1::<f32>()?;
-    let context: std::collections::HashSet<_> = context.iter().collect();
-    for (token_id, logit) in logits.iter_mut().enumerate() {
-        if context.contains(&(token_id as u32)) {
-            if *logit >= 0. {
-                *logit /= penalty
-            } else {
-                *logit *= penalty
-            }
-        }
-    }
-    let logits_len = logits.len();
-    Tensor::from_vec(logits, logits_len, &Device::Cpu)
-}
-
 fn format_size(size_in_bytes: usize) -> String {
     if size_in_bytes < 1_000 {
         format!("{}B", size_in_bytes)

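Note that this removed per-example version always rebuilt the penalized logits on Device::Cpu; the shared helper added below instead queries logits.device() and rebuilds the tensor on whatever device the input lives on.
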
@@ -670,7 +654,11 @@ fn main() -> anyhow::Result<()> {
             logits
         } else {
             let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
-            apply_repeat_penalty(&logits, args.repeat_penalty, &all_tokens[start_at..])?
+            candle_transformers::utils::apply_repeat_penalty(
+                &logits,
+                args.repeat_penalty,
+                &all_tokens[start_at..],
+            )?
         };
         next_token = logits_processor.sample(&logits)?;
         all_tokens.push(next_token);

@@ -1,3 +1,4 @@
 pub mod generation;
 pub mod models;
 pub mod pipelines;
+pub mod utils;

candle-transformers/src/utils.rs (new file, 18 lines)
@@ -0,0 +1,18 @@
+use candle::{Result, Tensor};
+
+pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
+    let device = logits.device();
+    let mut logits = logits.to_vec1::<f32>()?;
+    let context: std::collections::HashSet<_> = context.iter().collect();
+    for (token_id, logit) in logits.iter_mut().enumerate() {
+        if context.contains(&(token_id as u32)) {
+            if *logit >= 0. {
+                *logit /= penalty
+            } else {
+                *logit *= penalty
+            }
+        }
+    }
+    let logits_len = logits.len();
+    Tensor::from_vec(logits, logits_len, device)
+}

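A hypothetical round trip through the new helper (values chosen for illustration; the assertion follows the divide/multiply rule above):

    use candle::{Device, Result, Tensor};

    fn demo() -> Result<()> {
        let logits = Tensor::new(&[2.0f32, -1.0, 0.5], &Device::Cpu)?;
        // Token ids 0 and 1 appear in the recent context, so both get penalized.
        let penalized =
            candle_transformers::utils::apply_repeat_penalty(&logits, 2.0, &[0, 1])?;
        assert_eq!(penalized.to_vec1::<f32>()?, [1.0, -2.0, 0.5]);
        Ok(())
    }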