Add some optional repeat penalty. (#623)

* Add some optional repeat penalty.

* Add the missing files.
This commit is contained in:
Laurent Mazare
2023-08-27 10:48:45 +01:00
committed by GitHub
parent 5320aa6b7d
commit 6e485f2deb
4 changed files with 42 additions and 17 deletions

View File

@ -83,6 +83,14 @@ struct Args {
/// (same structure as huggingface online) /// (same structure as huggingface online)
#[arg(long)] #[arg(long)]
local_weights: Option<String>, local_weights: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
} }
fn main() -> Result<()> { fn main() -> Result<()> {
@ -200,6 +208,16 @@ fn main() -> Result<()> {
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, index_pos)?; let logits = llama.forward(&input, index_pos)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len(); index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?; let next_token = logits_processor.sample(&logits)?;

View File

@ -533,22 +533,6 @@ fn print_token(next_token: u32, tokenizer: &Tokenizer) {
} }
} }
fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
let mut logits = logits.to_vec1::<f32>()?;
let context: std::collections::HashSet<_> = context.iter().collect();
for (token_id, logit) in logits.iter_mut().enumerate() {
if context.contains(&(token_id as u32)) {
if *logit >= 0. {
*logit /= penalty
} else {
*logit *= penalty
}
}
}
let logits_len = logits.len();
Tensor::from_vec(logits, logits_len, &Device::Cpu)
}
fn format_size(size_in_bytes: usize) -> String { fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 { if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes) format!("{}B", size_in_bytes)
@ -670,7 +654,11 @@ fn main() -> anyhow::Result<()> {
logits logits
} else { } else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
apply_repeat_penalty(&logits, args.repeat_penalty, &all_tokens[start_at..])? candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
}; };
next_token = logits_processor.sample(&logits)?; next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token); all_tokens.push(next_token);

View File

@ -1,3 +1,4 @@
pub mod generation; pub mod generation;
pub mod models; pub mod models;
pub mod pipelines; pub mod pipelines;
pub mod utils;

View File

@ -0,0 +1,18 @@
use candle::{Result, Tensor};
pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
let device = logits.device();
let mut logits = logits.to_vec1::<f32>()?;
let context: std::collections::HashSet<_> = context.iter().collect();
for (token_id, logit) in logits.iter_mut().enumerate() {
if context.contains(&(token_id as u32)) {
if *logit >= 0. {
*logit /= penalty
} else {
*logit *= penalty
}
}
}
let logits_len = logits.len();
Tensor::from_vec(logits, logits_len, device)
}