Softmax numerical stability. (#267)

* Softmax numerical stability.

* Fix the flash-attn test.
This commit is contained in:
Laurent Mazare
2023-07-28 13:13:01 +01:00
committed by GitHub
parent 68eab38de6
commit 3eb2bc6d07
28 changed files with 117 additions and 188 deletions

View File

@ -11,7 +11,7 @@ extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{safetensors::Load, DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_nn::{ops::softmax, VarBuilder};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
@ -120,9 +120,7 @@ impl Decoder {
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
no_speech_prob = logits
.get(0)?
.softmax(0)?
no_speech_prob = softmax(&logits.get(0)?, 0)?
.get(NO_SPEECH_TOKEN as usize)?
.to_scalar::<f32>()? as f64;
}
@ -132,7 +130,7 @@ impl Decoder {
.get(seq_len - 1)?
.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = (&logits / t)?.softmax(0)?;
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
@ -146,8 +144,7 @@ impl Decoder {
.unwrap()
};
tokens.push(next_token);
let prob = logits
.softmax(candle::D::Minus1)?
let prob = softmax(&logits, candle::D::Minus1)?
.get(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {

View File

@ -2,7 +2,7 @@
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
@ -154,7 +154,7 @@ impl MultiHeadAttention {
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
qk = qk.broadcast_add(&mask)?
}
let w = qk.softmax(candle::D::Minus1)?;
let w = softmax(&qk, candle::D::Minus1)?;
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
Ok(wv)
}