Softmax numerical stability. (#267)

* Softmax numerical stability.

* Fix the flash-attn test.
This commit is contained in:
Laurent Mazare
2023-07-28 13:13:01 +01:00
committed by GitHub
parent 68eab38de6
commit 3eb2bc6d07
28 changed files with 117 additions and 188 deletions

View File

@ -30,16 +30,6 @@ fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
Ok(mask)
}
// TODO: Use a numerically stable implementation by default.
fn softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
let d = d.to_index(xs.shape(), "log-softmax")?;
let max = xs.max_keepdim(d)?;
let diff = xs.broadcast_sub(&max)?;
let num = diff.exp()?;
let den = num.sum_keepdim(d)?;
num.broadcast_div(&den)
}
#[derive(Debug)]
pub struct Config {
pub vocab_size: usize,
@ -192,7 +182,7 @@ impl Attention {
let mask_value =
Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
let attn_weights = softmax(&attn_weights, D::Minus1)?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let value = value.contiguous()?;
let attn_output = if self.multi_query {
attn_weights