mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Softmax numerical stability. (#267)
* Softmax numerical stability. * Fix the flash-attn test.
This commit is contained in:
@ -17,7 +17,7 @@ impl LogitsProcessor {
|
||||
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
||||
let logits = logits.to_dtype(DType::F32)?;
|
||||
let next_token = if let Some(temperature) = self.temperature {
|
||||
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
|
||||
let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
|
||||
let prs: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
||||
distr.sample(&mut self.rng) as u32
|
||||
|
Reference in New Issue
Block a user