Mirror of https://github.com/huggingface/candle.git
Softmax numerical stability. (#267)

* Softmax numerical stability.
* Fix the flash-attn test.
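The change swaps the `Tensor::softmax` method calls for `candle_nn::ops::softmax`, which shifts each row by its maximum before exponentiating. Below is a minimal sketch of that max-subtraction trick, written against candle tensor ops assumed to be available here (`max_keepdim`, `broadcast_sub`, `exp`, `sum_keepdim`, `broadcast_div`); it illustrates the idea rather than reproducing the crate's exact implementation.

```rust
use candle::{Device, Result, Tensor};

// Numerically stable softmax along dimension `dim`: shift by the per-row maximum
// before exponentiating. Illustration only, assuming the candle tensor ops named
// above exist with these signatures.
fn stable_softmax(xs: &Tensor, dim: usize) -> Result<Tensor> {
    let max = xs.max_keepdim(dim)?;         // row-wise maximum, kept for broadcasting
    let shifted = xs.broadcast_sub(&max)?;  // every entry is now <= 0
    let num = shifted.exp()?;               // exp() stays in (0, 1], cannot overflow
    let den = num.sum_keepdim(dim)?;        // row-wise normalizer
    num.broadcast_div(&den)
}

fn main() -> Result<()> {
    // A naive exp(1000.0) overflows f32 to +inf and the softmax becomes NaN;
    // the shifted version returns finite probabilities that sum to 1.
    let xs = Tensor::new(&[[1000f32, 999., 998.]], &Device::Cpu)?;
    println!("{}", stable_softmax(&xs, 1)?);
    Ok(())
}
```

Since softmax(x) = softmax(x - c) for any constant c, the shift is mathematically a no-op; it only changes the range that exp() sees.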
@@ -2,7 +2,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{Device, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
 use serde::Deserialize;
 
 // The names in comments correspond to the original implementation:
@@ -154,7 +154,7 @@ impl MultiHeadAttention {
             let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
             qk = qk.broadcast_add(&mask)?
         }
-        let w = qk.softmax(candle::D::Minus1)?;
+        let w = softmax(&qk, candle::D::Minus1)?;
         let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
         Ok(wv)
     }