Mirror of https://github.com/huggingface/candle.git
Use softmax-last-dim where possible. (#1057)
@@ -182,7 +182,7 @@ impl Attention {
         let mask_value =
             Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
         let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
-        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
+        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         let value = value.contiguous()?;
         let attn_output = if self.multi_query {
             attn_weights
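For context, the change swaps the generic softmax(&t, D::Minus1) for softmax_last_dim(&t), which uses a dedicated implementation specialised for the last (contiguous) dimension while producing the same values. Below is a minimal sketch of that equivalence, not part of the commit; it assumes the published candle-core / candle-nn crates (inside the candle workspace the core crate is imported as `candle` instead), and the comparison harness is purely illustrative.

    // Sketch: softmax over the last dim vs. the specialised softmax_last_dim.
    use candle_core::{Device, Result, Tensor, D};

    fn main() -> Result<()> {
        // Stand-in for attention weights: a small random 3-D tensor on the CPU.
        let attn_weights = Tensor::randn(0f32, 1f32, (2, 4, 8), &Device::Cpu)?;

        // Generic softmax over the last dimension (the call being replaced).
        let a = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
        // Specialised last-dim softmax (the call introduced by this commit).
        let b = candle_nn::ops::softmax_last_dim(&attn_weights)?;

        // The two should agree up to floating-point rounding.
        let max_abs_diff = (&a - &b)?.abs()?.flatten_all()?.max(0)?.to_scalar::<f32>()?;
        println!("max |softmax - softmax_last_dim| = {max_abs_diff}");
        Ok(())
    }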