Fix the causal mask computation.

This commit is contained in:
laurent
2023-06-25 20:19:30 +01:00
parent 8e404eb125
commit 25bcad290e

View File

@ -290,10 +290,14 @@ impl CausalSelfAttention {
let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
let device = x.device();
// TODO: If we support bool or u8 tensors, this would be better.
let mask = Tensor::new(1u32, &device)?
.broadcast_as(&[t, t])?
// TODO: .lower_triangle()?
.reshape(&[1, 1, t, t])?;
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
.collect();
// Once lower_triangle is available, use the following:
//let mask = Tensor::new(1u32, &device)?
// .broadcast_as(&[t, t])?
// .lower_triangle()?
let mask = Tensor::from_slice(&mask, (t, t), &device)?.reshape(&[1, 1, t, t])?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = att.softmax(att.rank() - 1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.