mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Use softmax-last-dim in the quantized example. (#848)
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use candle::{DType, Error, Result, Tensor, D};
|
||||
use candle::{DType, Error, Result, Tensor};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
|
||||
pub struct LogitsProcessor {
|
||||
@ -9,6 +9,11 @@ pub struct LogitsProcessor {
|
||||
|
||||
impl LogitsProcessor {
|
||||
pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
|
||||
let temperature = if temperature.map_or(true, |v| v < 1e-7) {
|
||||
None
|
||||
} else {
|
||||
temperature
|
||||
};
|
||||
Self {
|
||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||
temperature,
|
||||
@ -27,7 +32,7 @@ impl LogitsProcessor {
|
||||
Ok(next_token)
|
||||
}
|
||||
|
||||
fn sample_multi(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
||||
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
||||
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
||||
let next_token = distr.sample(&mut self.rng) as u32;
|
||||
Ok(next_token)
|
||||
@ -51,28 +56,26 @@ impl LogitsProcessor {
|
||||
cumsum += prs[*index];
|
||||
}
|
||||
}
|
||||
|
||||
// Sample with clamped probabilities.
|
||||
let next_token = self.sample_multi(prs)?;
|
||||
Ok(next_token)
|
||||
self.sample_multinomial(prs)
|
||||
}
|
||||
|
||||
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
||||
let logits = logits.to_dtype(DType::F32)?;
|
||||
let temperature = self.temperature.unwrap_or(0.);
|
||||
let top_p = self.top_p.unwrap_or(1.);
|
||||
let next_token = if temperature == 0. {
|
||||
self.sample_argmax(logits)?
|
||||
} else {
|
||||
let logits = &(&logits / temperature)?;
|
||||
let prs = candle_nn::ops::softmax(logits, D::Minus1)?;
|
||||
let mut prs: Vec<f32> = prs.to_vec1()?;
|
||||
if top_p <= 0.0 || top_p >= 1.0 {
|
||||
// simply sample from the predicted probability distribution
|
||||
self.sample_multi(&prs)?
|
||||
} else {
|
||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||
self.sample_topp(&mut prs, top_p as f32)?
|
||||
let next_token = match self.temperature {
|
||||
None => self.sample_argmax(logits)?,
|
||||
Some(temperature) => {
|
||||
let logits = &(&logits / temperature)?;
|
||||
let prs = candle_nn::ops::softmax_last_dim(logits)?;
|
||||
let mut prs: Vec<f32> = prs.to_vec1()?;
|
||||
let top_p = self.top_p.unwrap_or(1.);
|
||||
if top_p <= 0.0 || top_p >= 1.0 {
|
||||
// simply sample from the predicted probability distribution
|
||||
self.sample_multinomial(&prs)?
|
||||
} else {
|
||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||
self.sample_topp(&mut prs, top_p as f32)?
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(next_token)
|
||||
|
@ -144,7 +144,7 @@ impl LayerWeights {
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let mask = mask.broadcast_as(att.shape())?;
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
|
Reference in New Issue
Block a user