Softmax numerical stability. (#267)
* Softmax numerical stability.

* Fix the flash-attn test.
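The change below swaps the `.softmax(D::Minus1)` tensor method for `candle_nn::ops::softmax` so that attention weights are computed with a numerically stable softmax. The usual stability trick is to subtract the row maximum before exponentiating, which keeps every exponent at or below zero and prevents overflow on large logits such as unmasked attention scores. A minimal plain-Rust sketch of that trick follows; the `stable_softmax` helper and its `main` driver are illustrative only, not candle's implementation:

// Sketch of the max-subtraction trick on a plain f32 slice (illustrative,
// not candle's code).
fn stable_softmax(xs: &[f32]) -> Vec<f32> {
    // Subtracting the maximum keeps every exponent <= 0, so exp() cannot
    // overflow to infinity even for very large inputs.
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // Without the shift, exp(1000.0) overflows f32 to infinity and the
    // naive exp/sum ratio degenerates into NaN and 0 entries; the shifted
    // version stays finite and sums to 1.
    let logits = [1000.0f32, 999.0, 998.0];
    println!("{:?}", stable_softmax(&logits));
}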
@@ -187,7 +187,7 @@ impl MusicgenAttention {
         let attn_weights = attn_weights
             .reshape((b_sz, self.num_heads, tgt_len, src_len))?
             .broadcast_add(attention_mask)?;
-        let attn_weights = attn_weights.softmax(D::Minus1)?;
+        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
         // TODO: layer_head_mask?
         let attn_output = attn_weights
             .matmul(&value_states)?
@@ -223,7 +223,7 @@ impl T5Attention {
             .transpose(1, 2)?;
         let scores = q.matmul(&k.t()?)?;
         // TODO: position_bias_masked
-        let attn_weights = scores.softmax(D::Minus1)?;
+        let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
         let attn_output = attn_weights.matmul(&v)?;
         let attn_output = self.o.forward(&attn_output)?;
         Ok(attn_output)