mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Compare commits
1 Commits
0.9.0-alph
...
improve-sa
Author | SHA1 | Date | |
---|---|---|---|
f7980abbcd |
@ -1,4 +1,4 @@
|
||||
use candle::{DType, Error, Result, Tensor};
|
||||
use candle::{DType, Error, IndexOp, Result, Tensor, D};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
@ -73,17 +73,15 @@ impl LogitsProcessor {
|
||||
}
|
||||
|
||||
// top-k sampling samples from the k tokens with the largest probabilities.
|
||||
fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
|
||||
if top_k >= prs.len() {
|
||||
self.sample_multinomial(prs)
|
||||
} else {
|
||||
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
|
||||
let (indices, _, _) =
|
||||
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
|
||||
let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
|
||||
let index = self.sample_multinomial(&prs)?;
|
||||
Ok(indices[index as usize] as u32)
|
||||
}
|
||||
fn sample_topk(&mut self, logits: &Tensor, top_k: usize, temperature: f64) -> Result<u32> {
|
||||
let arg_sort = logits.arg_sort_last_dim(false)?;
|
||||
let top_k_indices = arg_sort.narrow(candle::D::Minus1, 0, top_k)?;
|
||||
let top_k_logits = logits.gather(&top_k_indices, D::Minus1)?;
|
||||
let top_k_logits = (&top_k_logits / temperature)?;
|
||||
let top_k_prs = candle_nn::ops::softmax_last_dim(&top_k_logits)?;
|
||||
let top_k_prs = top_k_prs.to_vec1()?;
|
||||
let index = self.sample_multinomial(&top_k_prs)?;
|
||||
Ok(top_k_indices.i(index as usize)?.to_vec0::<u32>()?)
|
||||
}
|
||||
|
||||
// top-k sampling samples from the k tokens with the largest probabilities.
|
||||
@ -137,8 +135,12 @@ impl LogitsProcessor {
|
||||
}
|
||||
}
|
||||
Sampling::TopK { k, temperature } => {
|
||||
let mut prs = prs(*temperature)?;
|
||||
self.sample_topk(&mut prs, *k)?
|
||||
if *k >= logits.dim(D::Minus1)? {
|
||||
let prs = prs(*temperature)?;
|
||||
self.sample_multinomial(&prs)?
|
||||
} else {
|
||||
self.sample_topk(&logits, *k, *temperature)?
|
||||
}
|
||||
}
|
||||
Sampling::TopKThenTopP { k, p, temperature } => {
|
||||
let mut prs = prs(*temperature)?;
|
||||
|
Reference in New Issue
Block a user