Softmax numerical stability. (#267)

* Softmax numerical stability.

* Fix the flash-attn test.
Laurent Mazare
2023-07-28 13:13:01 +01:00
committed by GitHub
parent 68eab38de6
commit 3eb2bc6d07
28 changed files with 117 additions and 188 deletions


@@ -333,7 +333,7 @@ impl BertSelfAttention {
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
         let attention_probs = {
             let _enter_sm = self.span_softmax.enter();
-            attention_scores.softmax(candle::D::Minus1)?
+            candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
         };
         let attention_probs = self.dropout.forward(&attention_probs)?;
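
Note on the technique (not part of the diff): a naive softmax exp(x_i) / sum_j exp(x_j) overflows f32 once any logit exceeds roughly 88, since exp(89) > f32::MAX. Subtracting the row maximum before exponentiating leaves the result unchanged, because softmax is invariant under adding a constant to every input, while keeping every exponent non-positive. Below is a minimal sketch in plain Rust with an illustrative stable_softmax helper; it shows the idea only and is not candle_nn::ops::softmax's actual implementation, which operates on tensors along a given dimension.

// Sketch of a numerically stable softmax over a slice of finite f32 logits.
// Illustrative only; names and the scalar loop are assumptions, not candle code.
fn stable_softmax(xs: &[f32]) -> Vec<f32> {
    // Shift by the max so every exponent is <= 0 and exp() cannot overflow.
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // The naive form would compute exp(1000.0) = inf and return NaNs here;
    // the shifted form yields the expected probabilities.
    let logits = [1000.0_f32, 1000.5, 999.0];
    println!("{:?}", stable_softmax(&logits));
}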