Softmax numerical stability. (#267)

* Softmax numerical stability.

* Fix the flash-attn test.
Author: Laurent Mazare
Date: 2023-07-28 13:13:01 +01:00
Committed by: GitHub
Parent: 68eab38de6
Commit: 3eb2bc6d07
28 changed files with 117 additions and 188 deletions

@@ -2,7 +2,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{Device, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
 use serde::Deserialize;
@@ -154,7 +154,7 @@ impl MultiHeadAttention {
             let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
             qk = qk.broadcast_add(&mask)?
         }
-        let w = qk.softmax(candle::D::Minus1)?;
+        let w = softmax(&qk, candle::D::Minus1)?;
         let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
         Ok(wv)
     }
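
For context, the stability fix behind switching to `candle_nn::ops::softmax` is the standard max-subtraction trick: softmax is invariant under subtracting a constant from each row, so subtracting the per-row maximum leaves the result unchanged while keeping exp() from overflowing on large logits. Below is a minimal sketch of that trick against candle's public tensor API; the `stable_softmax` helper is hypothetical, an illustration rather than the library's actual implementation.

use candle::{Device, Result, Tensor, D};

// Hypothetical helper showing the max-subtraction trick; a sketch of what a
// numerically stable softmax does, not candle_nn's actual code.
fn stable_softmax(xs: &Tensor, dim: D) -> Result<Tensor> {
    // softmax(x) == softmax(x - c) for any per-row constant c, so shifting
    // by the row maximum bounds every exponent: exp(x - max) <= 1.
    let max = xs.max_keepdim(dim)?;
    let shifted = xs.broadcast_sub(&max)?;
    let num = shifted.exp()?;
    let den = num.sum_keepdim(dim)?;
    num.broadcast_div(&den)
}

fn main() -> Result<()> {
    // Logits this large make a naive exp()-then-normalize overflow in f32.
    let logits = Tensor::new(&[1000f32, 1001., 1002.], &Device::Cpu)?;
    let probs = stable_softmax(&logits, D::Minus1)?;
    println!("{probs}"); // finite probabilities that sum to 1
    Ok(())
}

With the naive formulation, exp(1000.0) is already infinite in f32, so the normalization produces NaNs; after the shift the largest exponent is exp(0) = 1 and the result stays finite, which is what the call-site change in the hunk above buys.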