mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Use softmax-last-dim where possible. (#1057)
This commit is contained in:
@ -256,7 +256,7 @@ impl MHA {
|
||||
f32::NEG_INFINITY,
|
||||
)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
|
||||
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
|
||||
// attn_weights: b*h,t,s, v: b*h,s,d
|
||||
|
Reference in New Issue
Block a user