mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
19 lines
628 B
Rust
19 lines
628 B
Rust
use candle::{Result, Tensor};
|
|
|
|
pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
|
|
let device = logits.device();
|
|
let mut logits = logits.to_vec1::<f32>()?;
|
|
let context: std::collections::HashSet<_> = context.iter().collect();
|
|
for (token_id, logit) in logits.iter_mut().enumerate() {
|
|
if context.contains(&(token_id as u32)) {
|
|
if *logit >= 0. {
|
|
*logit /= penalty
|
|
} else {
|
|
*logit *= penalty
|
|
}
|
|
}
|
|
}
|
|
let logits_len = logits.len();
|
|
Tensor::from_vec(logits, logits_len, device)
|
|
}
|