Use softmax-last-dim where possible. (#1057)

This commit is contained in:
Laurent Mazare
2023-10-08 13:16:42 +01:00
committed by GitHub
parent 9abeddd750
commit 783735cf22
5 changed files with 5 additions and 5 deletions

View File

@ -256,7 +256,7 @@ impl MHA {
f32::NEG_INFINITY,
)?,
};
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
// attn_weights: b*h,t,s, v: b*h,s,d