Use softmax-last-dim where possible. (#1057)

Author: Laurent Mazare
Date: 2023-10-08 13:16:42 +01:00
Committed by: GitHub
Parent: 9abeddd750
Commit: 783735cf22

5 changed files with 5 additions and 5 deletions


@@ -182,7 +182,7 @@ impl Attention {
         let mask_value =
             Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
         let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
-        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
+        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         let value = value.contiguous()?;
         let attn_output = if self.multi_query {
             attn_weights
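
For context, `candle_nn::ops::softmax_last_dim` computes the same values as `softmax(.., D::Minus1)` but takes a path specialized for the trailing dimension, which is why this is a drop-in swap: attention weights are always normalized over the last (sequence) axis. Below is a minimal sketch of the equivalence, assuming `candle-core` and `candle-nn` as dependencies; the tensor shape and the tolerance are illustrative, not taken from the commit.

```rust
use candle_core::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    // Illustrative attention-weight shape: (batch, heads, seq, seq).
    let attn_weights = Tensor::randn(0f32, 1f32, (2, 4, 8, 8), &Device::Cpu)?;

    // Generic softmax, reducing over the last dimension.
    let generic = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
    // Specialized last-dim path; same math, as used in this commit.
    let fused = candle_nn::ops::softmax_last_dim(&attn_weights)?;

    // The two should agree up to floating-point noise.
    let max_diff = (generic - fused)?
        .abs()?
        .flatten_all()?
        .max(0)?
        .to_scalar::<f32>()?;
    assert!(max_diff < 1e-5);
    Ok(())
}
```

The motivation for preferring the specialized call, presumably, is that it can dispatch to a dedicated kernel rather than the generic reduce-over-arbitrary-dim code path, while leaving the results unchanged.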