Mirror of https://github.com/huggingface/candle.git
Avoid the attention mask where possible. (#1933)
@@ -194,8 +194,12 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?.contiguous()?;

         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = if seq_len <= 1 {
+            att
+        } else {
+            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            masked_fill(&att, &mask, f32::NEG_INFINITY)?
+        };
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
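Why skipping the mask is safe here: during autoregressive decoding with a KV cache, each step passes a single query token (seq_len == 1) that is allowed to attend to every cached position, so the causal mask contains no set entries and the broadcast_as + masked_fill pair is pure overhead. Below is a minimal standalone sketch of that argument, assuming cache.mask(seq_len) produces the usual strictly upper-triangular mask (1 wherever key index j lies after query index i); causal_mask is a hypothetical stand-in for illustration, not part of the candle API.

// Hypothetical stand-in for cache.mask(seq_len): a flattened seq_len x seq_len
// matrix with 1 wherever key index j comes strictly after query index i.
fn causal_mask(seq_len: usize) -> Vec<u8> {
    (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
        .collect()
}

fn main() {
    // Prompt processing (seq_len = 4): future positions are masked,
    // so masked_fill genuinely changes the attention scores.
    assert_eq!(
        causal_mask(4),
        vec![
            0, 1, 1, 1, //
            0, 0, 1, 1, //
            0, 0, 0, 1, //
            0, 0, 0, 0, //
        ]
    );
    // Single-token decode (seq_len = 1): the mask is all zeros, so applying
    // it is a no-op; the commit skips the broadcast and fill in that case.
    assert_eq!(causal_mask(1), vec![0]);
    println!("seq_len <= 1 => empty causal mask, masking can be skipped");
}

On the prompt pass (seq_len > 1) the mask still matters, which is why the else branch keeps the original broadcast and masked_fill unchanged.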