diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs index 1e63956b..c4a2d1db 100644 --- a/candle-transformers/src/models/bigcode.rs +++ b/candle-transformers/src/models/bigcode.rs @@ -182,7 +182,7 @@ impl Attention { let mask_value = Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?; let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?; - let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; let value = value.contiguous()?; let attn_output = if self.multi_query { attn_weights diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index b2fa2860..1ef8a984 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -275,7 +275,7 @@ impl MHA { f32::NEG_INFINITY, )?, }; - let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; // output = torch.einsum('bhts,bshd->bthd', attention_drop, v) // attn_weights: b*h,t,s, v: b*h,s,d diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs index e458cf5c..f7eebb72 100644 --- a/candle-transformers/src/models/quantized_mixformer.rs +++ b/candle-transformers/src/models/quantized_mixformer.rs @@ -256,7 +256,7 @@ impl MHA { f32::NEG_INFINITY, )?, }; - let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; // output = torch.einsum('bhts,bshd->bthd', attention_drop, v) // attn_weights: b*h,t,s, v: b*h,s,d diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs index bf5797e9..398e82a7 100644 --- a/candle-transformers/src/models/quantized_t5.rs +++ b/candle-transformers/src/models/quantized_t5.rs @@ -441,7 +441,7 @@ impl T5Attention { let attn_weights = { let _enter = self.span_sm.enter(); - candle_nn::ops::softmax(&scores, D::Minus1)? + candle_nn::ops::softmax_last_dim(&scores)? }; let attn_output = attn_weights.matmul(&v)?; let attn_output = attn_output diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index bdfabf28..9b3d97b8 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -441,7 +441,7 @@ impl T5Attention { let attn_weights = { let _enter = self.span_sm.enter(); - candle_nn::ops::softmax(&scores, D::Minus1)? + candle_nn::ops::softmax_last_dim(&scores)? }; let attn_output = attn_weights.matmul(&v)?; let attn_output = attn_output