mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Include topk sampling in the quantized example. (#2005)
* Include topk sampling in the quantized example. * Also sample with top-k on the mistral side.
This commit is contained in:
@ -7,6 +7,7 @@ pub enum Sampling {
|
||||
All { temperature: f64 },
|
||||
TopK { k: usize, temperature: f64 },
|
||||
TopP { p: f64, temperature: f64 },
|
||||
TopKThenTopP { k: usize, p: f64, temperature: f64 },
|
||||
}
|
||||
|
||||
pub struct LogitsProcessor {
|
||||
@ -77,7 +78,6 @@ impl LogitsProcessor {
|
||||
self.sample_multinomial(prs)
|
||||
} else {
|
||||
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
|
||||
// Sort by descending probability.
|
||||
let (indices, _, _) =
|
||||
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
|
||||
let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
|
||||
@ -86,6 +86,26 @@ impl LogitsProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
// top-k sampling samples from the k tokens with the largest probabilities.
|
||||
// then top-p sampling.
|
||||
fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
|
||||
if top_k >= prs.len() {
|
||||
self.sample_topp(prs, top_p)
|
||||
} else {
|
||||
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
|
||||
let (indices, _, _) =
|
||||
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
|
||||
let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
|
||||
let sum_p = prs.iter().sum::<f32>();
|
||||
let index = if top_p <= 0.0 || top_p >= sum_p {
|
||||
self.sample_multinomial(&prs)?
|
||||
} else {
|
||||
self.sample_topp(&mut prs, top_p)?
|
||||
};
|
||||
Ok(indices[index as usize] as u32)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
||||
self.sample_f(logits, |_| {})
|
||||
}
|
||||
@ -120,6 +140,10 @@ impl LogitsProcessor {
|
||||
let mut prs = prs(*temperature)?;
|
||||
self.sample_topk(&mut prs, *k)?
|
||||
}
|
||||
Sampling::TopKThenTopP { k, p, temperature } => {
|
||||
let mut prs = prs(*temperature)?;
|
||||
self.sample_topk_topp(&mut prs, *k, *p as f32)?
|
||||
}
|
||||
};
|
||||
Ok(next_token)
|
||||
}
|
||||
|
Reference in New Issue
Block a user