diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index 6c8c8ae4..b1a567c3 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -1,4 +1,4 @@
-use candle::{DType, Error, Result, Tensor, D};
+use candle::{DType, Error, Result, Tensor};
 use rand::{distributions::Distribution, SeedableRng};
 
 pub struct LogitsProcessor {
@@ -9,6 +9,11 @@ pub struct LogitsProcessor {
 
 impl LogitsProcessor {
     pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
+        let temperature = if temperature.map_or(true, |v| v < 1e-7) {
+            None
+        } else {
+            temperature
+        };
         Self {
             rng: rand::rngs::StdRng::seed_from_u64(seed),
             temperature,
@@ -27,7 +32,7 @@ impl LogitsProcessor {
         Ok(next_token)
     }
 
-    fn sample_multi(&mut self, prs: &Vec<f32>) -> Result<u32> {
+    fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
         let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
         let next_token = distr.sample(&mut self.rng) as u32;
         Ok(next_token)
@@ -51,28 +56,26 @@ impl LogitsProcessor {
                 cumsum += prs[*index];
             }
         }
-        // Sample with clamped probabilities.
-        let next_token = self.sample_multi(prs)?;
-        Ok(next_token)
+        self.sample_multinomial(prs)
     }
 
     pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
         let logits = logits.to_dtype(DType::F32)?;
-        let temperature = self.temperature.unwrap_or(0.);
-        let top_p = self.top_p.unwrap_or(1.);
-        let next_token = if temperature == 0. {
-            self.sample_argmax(logits)?
-        } else {
-            let logits = &(&logits / temperature)?;
-            let prs = candle_nn::ops::softmax(logits, D::Minus1)?;
-            let mut prs: Vec<f32> = prs.to_vec1()?;
-            if top_p <= 0.0 || top_p >= 1.0 {
-                // simply sample from the predicted probability distribution
-                self.sample_multi(&prs)?
-            } else {
-                // top-p (nucleus) sampling, clamping the least likely tokens to zero
-                self.sample_topp(&mut prs, top_p as f32)?
+        let next_token = match self.temperature {
+            None => self.sample_argmax(logits)?,
+            Some(temperature) => {
+                let logits = &(&logits / temperature)?;
+                let prs = candle_nn::ops::softmax_last_dim(logits)?;
+                let mut prs: Vec<f32> = prs.to_vec1()?;
+                let top_p = self.top_p.unwrap_or(1.);
+                if top_p <= 0.0 || top_p >= 1.0 {
+                    // simply sample from the predicted probability distribution
+                    self.sample_multinomial(&prs)?
+                } else {
+                    // top-p (nucleus) sampling, clamping the least likely tokens to zero
+                    self.sample_topp(&mut prs, top_p as f32)?
+                }
+            }
             }
         };
         Ok(next_token)
     }
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index da0bd0b0..2988b0fb 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -144,7 +144,7 @@ impl LayerWeights {
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let mask = mask.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+        let att = candle_nn::ops::softmax_last_dim(&att)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;