Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
use softmax_last_dim (metal and cuda kernel) in llama attention layer (#2572)
@@ -341,7 +341,8 @@ impl CausalSelfAttention {
                 let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
                 masked_fill(&att, &mask, f32::NEG_INFINITY)?
             };
-            let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+
+            let att = candle_nn::ops::softmax_last_dim(&att)?;
             // Convert to contiguous as matmul doesn't support strided vs for now.
             att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
         };
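For reference, a minimal standalone sketch of what the change amounts to, assuming the candle_core and candle_nn crates on the CPU backend; the tensor shape and variable names below are illustrative and not taken from the llama code. Both calls normalize over the last dimension and yield the same values, but softmax_last_dim dispatches to the dedicated Metal and CUDA kernels when those backends are in use.

use candle_core::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Illustrative attention-score tensor: (batch, heads, seq_len, seq_len).
    let att = Tensor::randn(0f32, 1.0, (1, 2, 4, 4), &device)?;

    // Old call: generic softmax along an explicitly given dimension.
    let a = candle_nn::ops::softmax(&att, D::Minus1)?;
    // New call: last-dim softmax backed by dedicated Metal/CUDA kernels.
    let b = candle_nn::ops::softmax_last_dim(&att)?;

    // The results match; only the kernel path differs.
    let diff = (a - b)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    println!("sum of absolute differences: {diff}");
    Ok(())
}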