diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs index c250a186..85a9bbe2 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -1,4 +1,4 @@ -use candle::{DType, Error, Result, Tensor}; +use candle::{DType, Error, IndexOp, Result, Tensor, D}; use rand::{distributions::Distribution, SeedableRng}; #[derive(Clone, PartialEq, Debug)] @@ -73,17 +73,15 @@ impl LogitsProcessor { } // top-k sampling samples from the k tokens with the largest probabilities. - fn sample_topk(&mut self, prs: &mut Vec, top_k: usize) -> Result { - if top_k >= prs.len() { - self.sample_multinomial(prs) - } else { - let mut argsort_indices = (0..prs.len()).collect::>(); - let (indices, _, _) = - argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i])); - let prs = indices.iter().map(|&i| prs[i]).collect::>(); - let index = self.sample_multinomial(&prs)?; - Ok(indices[index as usize] as u32) - } + fn sample_topk(&mut self, logits: &Tensor, top_k: usize, temperature: f64) -> Result { + let arg_sort = logits.arg_sort_last_dim(false)?; + let top_k_indices = arg_sort.narrow(candle::D::Minus1, 0, top_k)?; + let top_k_logits = logits.gather(&top_k_indices, D::Minus1)?; + let top_k_logits = (&top_k_logits / temperature)?; + let top_k_prs = candle_nn::ops::softmax_last_dim(&top_k_logits)?; + let top_k_prs = top_k_prs.to_vec1()?; + let index = self.sample_multinomial(&top_k_prs)?; + Ok(top_k_indices.i(index as usize)?.to_vec0::()?) } // top-k sampling samples from the k tokens with the largest probabilities. @@ -137,8 +135,12 @@ impl LogitsProcessor { } } Sampling::TopK { k, temperature } => { - let mut prs = prs(*temperature)?; - self.sample_topk(&mut prs, *k)? + if *k >= logits.dim(D::Minus1)? { + let prs = prs(*temperature)?; + self.sample_multinomial(&prs)? + } else { + self.sample_topk(&logits, *k, *temperature)? + } } Sampling::TopKThenTopP { k, p, temperature } => { let mut prs = prs(*temperature)?;