Use softmax-last-dim where possible. (#1057)

This commit is contained in:
Laurent Mazare
2023-10-08 13:16:42 +01:00
committed by GitHub
parent 9abeddd750
commit 783735cf22
5 changed files with 5 additions and 5 deletions

View File

@ -182,7 +182,7 @@ impl Attention {
let mask_value = let mask_value =
Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?; Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?; let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let value = value.contiguous()?; let value = value.contiguous()?;
let attn_output = if self.multi_query { let attn_output = if self.multi_query {
attn_weights attn_weights

View File

@ -275,7 +275,7 @@ impl MHA {
f32::NEG_INFINITY, f32::NEG_INFINITY,
)?, )?,
}; };
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v) // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
// attn_weights: b*h,t,s, v: b*h,s,d // attn_weights: b*h,t,s, v: b*h,s,d

View File

@ -256,7 +256,7 @@ impl MHA {
f32::NEG_INFINITY, f32::NEG_INFINITY,
)?, )?,
}; };
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v) // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
// attn_weights: b*h,t,s, v: b*h,s,d // attn_weights: b*h,t,s, v: b*h,s,d

View File

@ -441,7 +441,7 @@ impl T5Attention {
let attn_weights = { let attn_weights = {
let _enter = self.span_sm.enter(); let _enter = self.span_sm.enter();
candle_nn::ops::softmax(&scores, D::Minus1)? candle_nn::ops::softmax_last_dim(&scores)?
}; };
let attn_output = attn_weights.matmul(&v)?; let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output let attn_output = attn_output

View File

@ -441,7 +441,7 @@ impl T5Attention {
let attn_weights = { let attn_weights = {
let _enter = self.span_sm.enter(); let _enter = self.span_sm.enter();
candle_nn::ops::softmax(&scores, D::Minus1)? candle_nn::ops::softmax_last_dim(&scores)?
}; };
let attn_output = attn_weights.matmul(&v)?; let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output let attn_output = attn_output