Use softmax-last-dim where possible. (#1057)

This commit is contained in:
Laurent Mazare
2023-10-08 13:16:42 +01:00
committed by GitHub
parent 9abeddd750
commit 783735cf22
5 changed files with 5 additions and 5 deletions

View File

@ -182,7 +182,7 @@ impl Attention {
let mask_value =
Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let value = value.contiguous()?;
let attn_output = if self.multi_query {
attn_weights

View File

@ -275,7 +275,7 @@ impl MHA {
f32::NEG_INFINITY,
)?,
};
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
// attn_weights: b*h,t,s, v: b*h,s,d

View File

@ -256,7 +256,7 @@ impl MHA {
f32::NEG_INFINITY,
)?,
};
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
// attn_weights: b*h,t,s, v: b*h,s,d

View File

@ -441,7 +441,7 @@ impl T5Attention {
let attn_weights = {
let _enter = self.span_sm.enter();
candle_nn::ops::softmax(&scores, D::Minus1)?
candle_nn::ops::softmax_last_dim(&scores)?
};
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output

View File

@ -441,7 +441,7 @@ impl T5Attention {
let attn_weights = {
let _enter = self.span_sm.enter();
candle_nn::ops::softmax(&scores, D::Minus1)?
candle_nn::ops::softmax_last_dim(&scores)?
};
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output