mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add some optional repeat penalty. (#623)
* Add some optional repeat penalty. * Add the missing files.
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
pub mod generation;
|
||||
pub mod models;
|
||||
pub mod pipelines;
|
||||
pub mod utils;
|
||||
|
18
candle-transformers/src/utils.rs
Normal file
18
candle-transformers/src/utils.rs
Normal file
@ -0,0 +1,18 @@
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
|
||||
let device = logits.device();
|
||||
let mut logits = logits.to_vec1::<f32>()?;
|
||||
let context: std::collections::HashSet<_> = context.iter().collect();
|
||||
for (token_id, logit) in logits.iter_mut().enumerate() {
|
||||
if context.contains(&(token_id as u32)) {
|
||||
if *logit >= 0. {
|
||||
*logit /= penalty
|
||||
} else {
|
||||
*logit *= penalty
|
||||
}
|
||||
}
|
||||
}
|
||||
let logits_len = logits.len();
|
||||
Tensor::from_vec(logits, logits_len, device)
|
||||
}
|
Reference in New Issue
Block a user