Avoid the attention mask where possible. (#1933)

This commit is contained in:
Laurent Mazare
2024-03-25 15:31:04 +01:00
committed by GitHub
parent cd254074f3
commit d3a8d291d5
3 changed files with 32 additions and 16 deletions

View File

@ -71,8 +71,12 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?.contiguous()?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = if seq_len <= 1 {
att
} else {
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
masked_fill(&att, &mask, f32::NEG_INFINITY)?
};
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;