mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 04:22:50 +00:00
Softmax numerical stability. (#267)
* Softmax numerical stability. * Fix the flash-attn test.
This commit is contained in:
@ -21,3 +21,4 @@ rayon = "1.7.0"
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
candle-nn = { path = "../candle-nn", features = ["cuda"] }
|
||||
|
@ -21,7 +21,7 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
|
||||
let k = k.to_dtype(DType::F32)?;
|
||||
let v = v.to_dtype(DType::F32)?;
|
||||
let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
|
||||
let att = att.softmax(D::Minus1)?;
|
||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
|
||||
Ok(output)
|
||||
|
Reference in New Issue
Block a user