Softmax numerical stability. (#267)
* Softmax numerical stability.
* Fix the flash-attn test.
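For background, the max-subtraction trick behind a numerically stable softmax can be sketched as follows. This is an illustrative, standalone sketch, not code from this commit: the function name stable_softmax, the main driver, and the example logits are assumptions made here for demonstration, while the tensor ops it uses (max_keepdim, broadcast_sub, sum_keepdim, broadcast_div) are the same ones that appear in the diff below, which replaces the per-model helper with the shared candle_nn::ops::softmax.

use candle::{Device, Result, Tensor, D};

// Numerically stable softmax: subtract the row-wise max before exponentiating so
// exp() never sees large positive inputs and cannot overflow; the shift cancels
// out in the normalization, so the probabilities are unchanged.
// Illustrative sketch only; in the repo this logic lives in candle_nn::ops::softmax.
fn stable_softmax<Dim: candle::shape::Dim>(xs: &Tensor, d: Dim) -> Result<Tensor> {
    let d = d.to_index(xs.shape(), "softmax")?;
    let max = xs.max_keepdim(d)?;       // per-row maximum, kept as a size-1 dim for broadcasting
    let diff = xs.broadcast_sub(&max)?; // largest logit becomes 0, the rest are <= 0
    let num = diff.exp()?;
    let den = num.sum_keepdim(d)?;
    num.broadcast_div(&den)
}

fn main() -> Result<()> {
    // Logits this large overflow a naive exp()-then-normalize softmax in f32,
    // but the shifted version returns finite probabilities.
    let xs = Tensor::new(&[[1000f32, 1001., 1002.]], &Device::Cpu)?;
    println!("{}", stable_softmax(&xs, D::Minus1)?);
    Ok(())
}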
@@ -30,16 +30,6 @@ fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
     Ok(mask)
 }
 
-// TODO: Use a numerically stable implementation by default.
-fn softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
-    let d = d.to_index(xs.shape(), "log-softmax")?;
-    let max = xs.max_keepdim(d)?;
-    let diff = xs.broadcast_sub(&max)?;
-    let num = diff.exp()?;
-    let den = num.sum_keepdim(d)?;
-    num.broadcast_div(&den)
-}
-
 #[derive(Debug)]
 pub struct Config {
     pub vocab_size: usize,
@@ -192,7 +182,7 @@ impl Attention {
         let mask_value =
             Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
         let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
-        let attn_weights = softmax(&attn_weights, D::Minus1)?;
+        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
         let value = value.contiguous()?;
         let attn_output = if self.multi_query {
             attn_weights