Compare commits

...

1 Commits

Author SHA1 Message Date
f7980abbcd Improve the sampling methods. 2024-05-04 10:53:30 +02:00

View File

@ -1,4 +1,4 @@
use candle::{DType, Error, Result, Tensor};
use candle::{DType, Error, IndexOp, Result, Tensor, D};
use rand::{distributions::Distribution, SeedableRng};
#[derive(Clone, PartialEq, Debug)]
@ -73,17 +73,15 @@ impl LogitsProcessor {
}
// top-k sampling samples from the k tokens with the largest probabilities.
fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
if top_k >= prs.len() {
self.sample_multinomial(prs)
} else {
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
let (indices, _, _) =
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
let index = self.sample_multinomial(&prs)?;
Ok(indices[index as usize] as u32)
}
fn sample_topk(&mut self, logits: &Tensor, top_k: usize, temperature: f64) -> Result<u32> {
let arg_sort = logits.arg_sort_last_dim(false)?;
let top_k_indices = arg_sort.narrow(candle::D::Minus1, 0, top_k)?;
let top_k_logits = logits.gather(&top_k_indices, D::Minus1)?;
let top_k_logits = (&top_k_logits / temperature)?;
let top_k_prs = candle_nn::ops::softmax_last_dim(&top_k_logits)?;
let top_k_prs = top_k_prs.to_vec1()?;
let index = self.sample_multinomial(&top_k_prs)?;
Ok(top_k_indices.i(index as usize)?.to_vec0::<u32>()?)
}
// top-k sampling samples from the k tokens with the largest probabilities.
@ -137,8 +135,12 @@ impl LogitsProcessor {
}
}
Sampling::TopK { k, temperature } => {
let mut prs = prs(*temperature)?;
self.sample_topk(&mut prs, *k)?
if *k >= logits.dim(D::Minus1)? {
let prs = prs(*temperature)?;
self.sample_multinomial(&prs)?
} else {
self.sample_topk(&logits, *k, *temperature)?
}
}
Sampling::TopKThenTopP { k, p, temperature } => {
let mut prs = prs(*temperature)?;