mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Softmax numerical stability. (#267)
* Softmax numerical stability. * Fix the flash-attn test.
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
use crate::model::{Config, Whisper};
|
||||
use anyhow::Error as E;
|
||||
use candle::{safetensors::Load, DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokenizers::Tokenizer;
|
||||
@ -127,9 +127,7 @@ impl Decoder {
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
// token logits and the probability for the according token.
|
||||
if i == 0 {
|
||||
no_speech_prob = logits
|
||||
.get(0)?
|
||||
.softmax(0)?
|
||||
no_speech_prob = softmax(&logits.get(0)?, 0)?
|
||||
.get(NO_SPEECH_TOKEN as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
}
|
||||
@ -139,7 +137,7 @@ impl Decoder {
|
||||
.get(seq_len - 1)?
|
||||
.broadcast_add(&self.suppress_tokens)?;
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = (&logits / t)?.softmax(0)?;
|
||||
let prs = softmax(&(&logits / t)?, 0)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
distr.sample(rng) as u32
|
||||
@ -153,8 +151,7 @@ impl Decoder {
|
||||
.unwrap()
|
||||
};
|
||||
tokens.push(next_token);
|
||||
let prob = logits
|
||||
.softmax(candle::D::Minus1)?
|
||||
let prob = softmax(&logits, candle::D::Minus1)?
|
||||
.get(next_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
|
||||
|
Reference in New Issue
Block a user