Also avoid the mask in the llama example.

laurent
2024-03-24 19:04:32 +01:00
parent 8c0db87992
commit cf7d7fcf2f


@@ -240,8 +240,12 @@ impl CausalSelfAttention {
             let k = k.to_dtype(DType::F32)?;
             let v = v.to_dtype(DType::F32)?;
             let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-            let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+            let att = if seq_len == 1 {
+                att
+            } else {
+                let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
+                masked_fill(&att, &mask, f32::NEG_INFINITY)?
+            };
             let att = candle_nn::ops::softmax(&att, D::Minus1)?;
             // Convert to contiguous as matmul doesn't support strided vs for now.
             att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
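
Why the mask can be skipped when seq_len == 1: during incremental decoding with the KV cache, each forward pass feeds a single new token, and that token is allowed to attend to every cached position, so the causal mask row contains no entries to set to -inf and the masked_fill is a no-op. The sketch below is a minimal, dependency-free illustration of that reasoning; causal_mask here is a hypothetical stand-in for what cache.mask(seq_len) in the llama example is assumed to produce (an upper-triangular 0/1 matrix), not the candle API itself.

// Minimal sketch (plain Rust, no candle dependency). causal_mask is an assumed
// stand-in for cache.mask(seq_len): 1 marks positions the query must not see.
fn causal_mask(seq_len: usize) -> Vec<Vec<u8>> {
    (0..seq_len)
        .map(|i| (0..seq_len).map(|j| u8::from(j > i)).collect())
        .collect()
}

fn main() {
    // Incremental decoding: one new token per step, so seq_len == 1 and the
    // mask is [[0]]; nothing would be filled with -inf, hence the shortcut.
    assert_eq!(causal_mask(1), vec![vec![0u8]]);
    // Prompt processing (seq_len > 1) still needs the mask: the strictly
    // upper-triangular entries hide future positions.
    assert_eq!(
        causal_mask(3),
        vec![vec![0, 1, 1], vec![0, 0, 1], vec![0, 0, 0]]
    );
    println!("mask is a no-op for seq_len == 1");
}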