Softmax numerical stability. (#267)

* Softmax numerical stability.

* Fix the flash-attn test.
Laurent Mazare
2023-07-28 13:13:01 +01:00
committed by GitHub
parent 68eab38de6
commit 3eb2bc6d07
28 changed files with 117 additions and 188 deletions
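For reference, the stability fix named in the title is the standard max-subtraction trick: subtracting the row maximum before exponentiating bounds every exponent by zero, so `exp` cannot overflow, while the shared factor cancels in the normalization and leaves the result unchanged. Below is a minimal, dependency-free Rust sketch of the idea (an illustration of the technique only, not the actual `candle_nn::ops::softmax` implementation):

```rust
/// Numerically stable softmax over a slice of logits.
/// Subtracting the maximum first bounds every exponent by 0, so `exp`
/// cannot overflow; the common factor exp(-max) cancels in the
/// normalization, so the probabilities are mathematically unchanged.
fn stable_softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // A naive softmax would overflow here: exp(1000.0) is infinite in f32.
    let probs = stable_softmax(&[1000.0, 999.0, 998.0]);
    println!("{probs:?}"); // approximately [0.665, 0.245, 0.090]
}
```

The diff below routes Falcon's attention scores through `candle_nn::ops::softmax` instead of the raw tensor `softmax` method, so this attention path picks up the stabilized implementation.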


@@ -309,11 +309,13 @@ impl FalconAttention {
         // Only handle the case where alibi is None here, and non-flash attention.
         let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
-        let attention_scores = attention_scores
-            .broadcast_add(&mask.squeeze(1)?)?
-            .to_dtype(DType::F32)?
-            .softmax(D::Minus1)?
-            .to_dtype(x.dtype())?;
+        let attention_scores = candle_nn::ops::softmax(
+            &attention_scores
+                .broadcast_add(&mask.squeeze(1)?)?
+                .to_dtype(DType::F32)?,
+            D::Minus1,
+        )?
+        .to_dtype(x.dtype())?;
         let attn_output = attention_scores
             .matmul(&value)?
             .reshape((b_sz, self.num_heads, seq_len, head_dim))?