mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Use softmax-last-dim in the quantized example. (#848)
This commit is contained in:
@ -1,4 +1,4 @@
|
|||||||
use candle::{DType, Error, Result, Tensor, D};
|
use candle::{DType, Error, Result, Tensor};
|
||||||
use rand::{distributions::Distribution, SeedableRng};
|
use rand::{distributions::Distribution, SeedableRng};
|
||||||
|
|
||||||
pub struct LogitsProcessor {
|
pub struct LogitsProcessor {
|
||||||
@ -9,6 +9,11 @@ pub struct LogitsProcessor {
|
|||||||
|
|
||||||
impl LogitsProcessor {
|
impl LogitsProcessor {
|
||||||
pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
|
pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
|
||||||
|
let temperature = if temperature.map_or(true, |v| v < 1e-7) {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
temperature
|
||||||
|
};
|
||||||
Self {
|
Self {
|
||||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||||
temperature,
|
temperature,
|
||||||
@ -27,7 +32,7 @@ impl LogitsProcessor {
|
|||||||
Ok(next_token)
|
Ok(next_token)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_multi(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
||||||
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
||||||
let next_token = distr.sample(&mut self.rng) as u32;
|
let next_token = distr.sample(&mut self.rng) as u32;
|
||||||
Ok(next_token)
|
Ok(next_token)
|
||||||
@ -51,28 +56,26 @@ impl LogitsProcessor {
|
|||||||
cumsum += prs[*index];
|
cumsum += prs[*index];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sample with clamped probabilities.
|
// Sample with clamped probabilities.
|
||||||
let next_token = self.sample_multi(prs)?;
|
self.sample_multinomial(prs)
|
||||||
Ok(next_token)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
||||||
let logits = logits.to_dtype(DType::F32)?;
|
let logits = logits.to_dtype(DType::F32)?;
|
||||||
let temperature = self.temperature.unwrap_or(0.);
|
let next_token = match self.temperature {
|
||||||
let top_p = self.top_p.unwrap_or(1.);
|
None => self.sample_argmax(logits)?,
|
||||||
let next_token = if temperature == 0. {
|
Some(temperature) => {
|
||||||
self.sample_argmax(logits)?
|
let logits = &(&logits / temperature)?;
|
||||||
} else {
|
let prs = candle_nn::ops::softmax_last_dim(logits)?;
|
||||||
let logits = &(&logits / temperature)?;
|
let mut prs: Vec<f32> = prs.to_vec1()?;
|
||||||
let prs = candle_nn::ops::softmax(logits, D::Minus1)?;
|
let top_p = self.top_p.unwrap_or(1.);
|
||||||
let mut prs: Vec<f32> = prs.to_vec1()?;
|
if top_p <= 0.0 || top_p >= 1.0 {
|
||||||
if top_p <= 0.0 || top_p >= 1.0 {
|
// simply sample from the predicted probability distribution
|
||||||
// simply sample from the predicted probability distribution
|
self.sample_multinomial(&prs)?
|
||||||
self.sample_multi(&prs)?
|
} else {
|
||||||
} else {
|
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
self.sample_topp(&mut prs, top_p as f32)?
|
||||||
self.sample_topp(&mut prs, top_p as f32)?
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Ok(next_token)
|
Ok(next_token)
|
||||||
|
@ -144,7 +144,7 @@ impl LayerWeights {
|
|||||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||||
let mask = mask.broadcast_as(att.shape())?;
|
let mask = mask.broadcast_as(att.shape())?;
|
||||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
||||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||||
let y = att.matmul(&v.contiguous()?)?;
|
let y = att.matmul(&v.contiguous()?)?;
|
||||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||||
|
Reference in New Issue
Block a user