Fix the causal mask computation.

laurent
2023-06-25 20:19:30 +01:00
parent 8e404eb125
commit 25bcad290e


@@ -290,10 +290,14 @@ impl CausalSelfAttention {
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
         let device = x.device();
         // TODO: If we support bool or u8 tensors, this would be better.
-        let mask = Tensor::new(1u32, &device)?
-            .broadcast_as(&[t, t])?
-            // TODO: .lower_triangle()?
-            .reshape(&[1, 1, t, t])?;
+        let mask: Vec<_> = (0..t)
+            .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
+            .collect();
+        // Once lower_triangle is available, use the following:
+        //let mask = Tensor::new(1u32, &device)?
+        //    .broadcast_as(&[t, t])?
+        //    .lower_triangle()?
+        let mask = Tensor::from_slice(&mask, (t, t), &device)?.reshape(&[1, 1, t, t])?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(att.rank() - 1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.