Use softmax-last-dim in the quantized example. (#848)

Author: Laurent Mazare
Date: 2023-09-14 18:29:24 +02:00
Committed by: GitHub
Parent: a0c6d5548c
Commit: 0a647875ec
2 changed files with 23 additions and 20 deletions
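Both files replace candle_nn::ops::softmax(xs, D::Minus1) with candle_nn::ops::softmax_last_dim(xs), a variant specialized for reducing over the last dimension. Below is a minimal sketch, not part of this commit, comparing the two calls; it assumes the candle and candle-nn crates as they are used inside this repository, and the two results should agree up to floating-point rounding.

use candle::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    let logits = Tensor::new(&[[1f32, 2., 3.], [3., 2., 1.]], &Device::Cpu)?;
    // Generic softmax, explicitly told to reduce over the last dimension.
    let a = candle_nn::ops::softmax(&logits, D::Minus1)?;
    // The specialized variant this commit switches to.
    let b = candle_nn::ops::softmax_last_dim(&logits)?;
    // Both should produce the same rows of probabilities.
    let diff = (&a - &b)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    println!("total abs difference: {diff}");
    Ok(())
}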

File 1 of 2:

@@ -1,4 +1,4 @@
-use candle::{DType, Error, Result, Tensor, D};
+use candle::{DType, Error, Result, Tensor};
 use rand::{distributions::Distribution, SeedableRng};
 
 pub struct LogitsProcessor {
@@ -9,6 +9,11 @@ pub struct LogitsProcessor {
 impl LogitsProcessor {
     pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
+        let temperature = if temperature.map_or(true, |v| v < 1e-7) {
+            None
+        } else {
+            temperature
+        };
         Self {
             rng: rand::rngs::StdRng::seed_from_u64(seed),
             temperature,
@@ -27,7 +32,7 @@ impl LogitsProcessor {
         Ok(next_token)
     }
 
-    fn sample_multi(&mut self, prs: &Vec<f32>) -> Result<u32> {
+    fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
         let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
         let next_token = distr.sample(&mut self.rng) as u32;
         Ok(next_token)
@@ -51,29 +56,27 @@ impl LogitsProcessor {
                 cumsum += prs[*index];
             }
         }
         // Sample with clamped probabilities.
-        let next_token = self.sample_multi(prs)?;
-        Ok(next_token)
+        self.sample_multinomial(prs)
     }
 
     pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
         let logits = logits.to_dtype(DType::F32)?;
-        let temperature = self.temperature.unwrap_or(0.);
-        let top_p = self.top_p.unwrap_or(1.);
-        let next_token = if temperature == 0. {
-            self.sample_argmax(logits)?
-        } else {
-            let logits = &(&logits / temperature)?;
-            let prs = candle_nn::ops::softmax(logits, D::Minus1)?;
-            let mut prs: Vec<f32> = prs.to_vec1()?;
-            if top_p <= 0.0 || top_p >= 1.0 {
-                // simply sample from the predicted probability distribution
-                self.sample_multi(&prs)?
-            } else {
-                // top-p (nucleus) sampling, clamping the least likely tokens to zero
-                self.sample_topp(&mut prs, top_p as f32)?
-            }
-        };
+        let next_token = match self.temperature {
+            None => self.sample_argmax(logits)?,
+            Some(temperature) => {
+                let logits = &(&logits / temperature)?;
+                let prs = candle_nn::ops::softmax_last_dim(logits)?;
+                let mut prs: Vec<f32> = prs.to_vec1()?;
+                let top_p = self.top_p.unwrap_or(1.);
+                if top_p <= 0.0 || top_p >= 1.0 {
+                    // simply sample from the predicted probability distribution
+                    self.sample_multinomial(&prs)?
+                } else {
+                    // top-p (nucleus) sampling, clamping the least likely tokens to zero
+                    self.sample_topp(&mut prs, top_p as f32)?
+                }
+            }
+        };
         Ok(next_token)
     }
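The constructor change above folds absent or near-zero temperatures into None, so sample can dispatch on self.temperature directly: None means greedy argmax decoding, and Some(t) means scaling the logits by 1/t before sampling, with no risk of dividing by a value near zero. A standalone sketch of just that normalization (plain Rust, no candle types; the 1e-7 threshold is taken from the diff):

// Near-zero (or absent) temperatures collapse to None, which the
// sampler later treats as greedy argmax decoding.
fn normalize_temperature(temperature: Option<f64>) -> Option<f64> {
    if temperature.map_or(true, |v| v < 1e-7) {
        None
    } else {
        temperature
    }
}

fn main() {
    assert_eq!(normalize_temperature(None), None);
    assert_eq!(normalize_temperature(Some(0.0)), None); // avoids logits / ~0.
    assert_eq!(normalize_temperature(Some(0.8)), Some(0.8));
}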

File 2 of 2:

@@ -144,7 +144,7 @@ impl LayerWeights {
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let mask = mask.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+        let att = candle_nn::ops::softmax_last_dim(&att)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
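For reference, the per-row computation that softmax_last_dim performs over the attention scores; a minimal, numerically stable sketch over a plain slice (an illustration of the math, not candle's actual kernel):

// Stable softmax over one row: subtract the row max before exponentiating
// so large scores don't overflow, then normalize so the row sums to 1.
fn softmax_row(row: &[f32]) -> Vec<f32> {
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    let probs = softmax_row(&[1.0, 2.0, 3.0]);
    println!("{probs:?}"); // ~[0.090, 0.245, 0.665]
}

Combined with the preceding masked_fill, positions set to f32::NEG_INFINITY exponentiate to zero, so masked tokens receive zero attention weight.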