Add some optional repeat penalty. (#623)

* Add some optional repeat penalty.

* Add the missing files.
This commit is contained in:
Laurent Mazare
2023-08-27 10:48:45 +01:00
committed by GitHub
parent 5320aa6b7d
commit 6e485f2deb
4 changed files with 42 additions and 17 deletions

View File

@ -533,22 +533,6 @@ fn print_token(next_token: u32, tokenizer: &Tokenizer) {
}
}
/// Returns a copy of `logits` in which every token id that appears in `context`
/// has been penalised.
///
/// * `logits` - read through `to_vec1::<f32>()`, so it is presumably a rank-1
///   `f32` tensor indexed by token id (NOTE(review): confirm against callers).
/// * `penalty` - non-negative logits are divided by this value and negative
///   logits are multiplied by it; for `penalty > 1.0` both operations lower
///   the logit.
/// * `context` - token ids to penalise. Duplicates are penalised only once and
///   ids that fall outside the logits vector are ignored.
///
/// The returned tensor is one-dimensional with the same length as the
/// extracted logits and is always created on `Device::Cpu`.
///
/// # Errors
///
/// Propagates any error from `Tensor::to_vec1` or `Tensor::from_vec`.
fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
    let mut logits = logits.to_vec1::<f32>()?;
    // Walk the context rather than the whole vocabulary: only these entries can
    // change, so this avoids one hash lookup per logit.
    let mut seen = std::collections::HashSet::with_capacity(context.len());
    for &token_id in context {
        // `insert` returns false when the id was already handled; skipping it
        // keeps the penalty from being applied twice to the same logit.
        if !seen.insert(token_id) {
            continue;
        }
        // Widening u32 -> usize; `get_mut` ignores ids outside the logits.
        if let Some(logit) = logits.get_mut(token_id as usize) {
            if *logit >= 0. {
                *logit /= penalty
            } else {
                *logit *= penalty
            }
        }
    }
    let logits_len = logits.len();
    Tensor::from_vec(logits, logits_len, &Device::Cpu)
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
@ -670,7 +654,11 @@ fn main() -> anyhow::Result<()> {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
apply_repeat_penalty(&logits, args.repeat_penalty, &all_tokens[start_at..])?
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);