Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
Use softmax-last-dim where possible. (#1057)
This commit is contained in:
@@ -182,7 +182,7 @@ impl Attention {
|
||||
let mask_value =
|
||||
Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
|
||||
let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
|
||||
- let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
let value = value.contiguous()?;
|
||||
let attn_output = if self.multi_query {
|
||||
attn_weights
|
||||
|
@@ -275,7 +275,7 @@ impl MHA {
|
||||
f32::NEG_INFINITY,
|
||||
)?,
|
||||
};
|
||||
- let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
|
||||
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
|
||||
// attn_weights: b*h,t,s, v: b*h,s,d
|
||||
|
@@ -256,7 +256,7 @@ impl MHA {
|
||||
f32::NEG_INFINITY,
|
||||
)?,
|
||||
};
|
||||
- let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
|
||||
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
|
||||
// attn_weights: b*h,t,s, v: b*h,s,d
|
||||
|
@@ -441,7 +441,7 @@ impl T5Attention {
|
||||
|
||||
let attn_weights = {
|
||||
let _enter = self.span_sm.enter();
|
||||
- candle_nn::ops::softmax(&scores, D::Minus1)?
|
||||
+ candle_nn::ops::softmax_last_dim(&scores)?
|
||||
};
|
||||
let attn_output = attn_weights.matmul(&v)?;
|
||||
let attn_output = attn_output
|
||||
|
@@ -441,7 +441,7 @@ impl T5Attention {
|
||||
|
||||
let attn_weights = {
|
||||
let _enter = self.span_sm.enter();
|
||||
- candle_nn::ops::softmax(&scores, D::Minus1)?
|
||||
+ candle_nn::ops::softmax_last_dim(&scores)?
|
||||
};
|
||||
let attn_output = attn_weights.matmul(&v)?;
|
||||
let attn_output = attn_output
|
||||
|
Reference in New Issue
Block a user