use candle::{DType, Error, Result, Tensor}; use rand::{distributions::Distribution, SeedableRng}; pub struct LogitsProcessor { rng: rand::rngs::StdRng, temperature: Option, top_p: Option, } impl LogitsProcessor { pub fn new(seed: u64, temperature: Option, top_p: Option) -> Self { let temperature = if temperature.map_or(true, |v| v < 1e-7) { None } else { temperature }; Self { rng: rand::rngs::StdRng::seed_from_u64(seed), temperature, top_p, } } fn sample_argmax(&mut self, logits: Tensor) -> Result { let logits_v: Vec = logits.to_vec1()?; let next_token = logits_v .iter() .enumerate() .max_by(|(_, u), (_, v)| u.total_cmp(v)) .map(|(i, _)| i as u32) .unwrap(); Ok(next_token) } fn sample_multinomial(&mut self, prs: &Vec) -> Result { let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?; let next_token = distr.sample(&mut self.rng) as u32; Ok(next_token) } fn sample_topp(&mut self, prs: &mut Vec, top_p: f32) -> Result { // top-p sampling (or "nucleus sampling") samples from the smallest set of // tokens that exceed probability top_p. This way we never sample tokens that // have very low probabilities and are less likely to go "off the rails". let mut argsort_indices = (0..prs.len()).collect::>(); // Sort by descending probability. argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap()); // Clamp smaller probabilities to zero. let mut cumsum = 0.; for index in &argsort_indices { if cumsum >= top_p { prs[*index] = 0.0; } else { cumsum += prs[*index]; } } // Sample with clamped probabilities. self.sample_multinomial(prs) } pub fn sample(&mut self, logits: &Tensor) -> Result { let logits = logits.to_dtype(DType::F32)?; let next_token = match self.temperature { None => self.sample_argmax(logits)?, Some(temperature) => { let logits = &(&logits / temperature)?; let prs = candle_nn::ops::softmax_last_dim(logits)?; let mut prs: Vec = prs.to_vec1()?; let top_p = self.top_p.unwrap_or(1.); if top_p <= 0.0 || top_p >= 1.0 { // simply sample from the predicted probability distribution self.sample_multinomial(&prs)? } else { // top-p (nucleus) sampling, clamping the least likely tokens to zero self.sample_topp(&mut prs, top_p as f32)? } } }; Ok(next_token) } }