Use softmax-last-dim in the quantized example. (#848)

This commit is contained in:
Laurent Mazare
2023-09-14 18:29:24 +02:00
committed by GitHub
parent a0c6d5548c
commit 0a647875ec
2 changed files with 23 additions and 20 deletions

View File

@ -1,4 +1,4 @@
use candle::{DType, Error, Result, Tensor, D};
use candle::{DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};
pub struct LogitsProcessor {
@ -9,6 +9,11 @@ pub struct LogitsProcessor {
impl LogitsProcessor {
pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
let temperature = if temperature.map_or(true, |v| v < 1e-7) {
None
} else {
temperature
};
Self {
rng: rand::rngs::StdRng::seed_from_u64(seed),
temperature,
@ -27,7 +32,7 @@ impl LogitsProcessor {
Ok(next_token)
}
fn sample_multi(&mut self, prs: &Vec<f32>) -> Result<u32> {
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
let next_token = distr.sample(&mut self.rng) as u32;
Ok(next_token)
@ -51,28 +56,26 @@ impl LogitsProcessor {
cumsum += prs[*index];
}
}
// Sample with clamped probabilities.
let next_token = self.sample_multi(prs)?;
Ok(next_token)
self.sample_multinomial(prs)
}
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let temperature = self.temperature.unwrap_or(0.);
let top_p = self.top_p.unwrap_or(1.);
let next_token = if temperature == 0. {
self.sample_argmax(logits)?
} else {
let logits = &(&logits / temperature)?;
let prs = candle_nn::ops::softmax(logits, D::Minus1)?;
let mut prs: Vec<f32> = prs.to_vec1()?;
if top_p <= 0.0 || top_p >= 1.0 {
// simply sample from the predicted probability distribution
self.sample_multi(&prs)?
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
self.sample_topp(&mut prs, top_p as f32)?
let next_token = match self.temperature {
None => self.sample_argmax(logits)?,
Some(temperature) => {
let logits = &(&logits / temperature)?;
let prs = candle_nn::ops::softmax_last_dim(logits)?;
let mut prs: Vec<f32> = prs.to_vec1()?;
let top_p = self.top_p.unwrap_or(1.);
if top_p <= 0.0 || top_p >= 1.0 {
// simply sample from the predicted probability distribution
self.sample_multinomial(&prs)?
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
self.sample_topp(&mut prs, top_p as f32)?
}
}
};
Ok(next_token)

View File

@ -144,7 +144,7 @@ impl LayerWeights {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = mask.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;