diff --git a/examples/llama/main.rs b/examples/llama/main.rs
index e09f5e2f..1db15816 100644
--- a/examples/llama/main.rs
+++ b/examples/llama/main.rs
@@ -290,10 +290,14 @@ impl CausalSelfAttention {
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
         let device = x.device();
         // TODO: If we support bool or u8 tensors, this would be better.
-        let mask = Tensor::new(1u32, &device)?
-            .broadcast_as(&[t, t])?
-            // TODO: .lower_triangle()?
-            .reshape(&[1, 1, t, t])?;
+        let mask: Vec<_> = (0..t)
+            .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
+            .collect();
+        // Once lower_triangle is available, use the following:
+        //let mask = Tensor::new(1u32, &device)?
+        //    .broadcast_as(&[t, t])?
+        //    .lower_triangle()?
+        let mask = Tensor::from_slice(&mask, (t, t), &device)?.reshape(&[1, 1, t, t])?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(att.rank() - 1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
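
For reference, here is a standalone sketch (not part of the diff) of the lower-triangular pattern the `flat_map` expression builds; the choice of `t = 4` and the `main` wrapper are only for illustration:

```rust
// Minimal sketch: print the causal mask produced by the flat_map above.
// Entry (i, j) is 1 when j <= i (position i may attend to position j),
// and 0 otherwise, yielding a lower-triangular matrix.
fn main() {
    let t = 4usize;
    let mask: Vec<u32> = (0..t)
        .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
        .collect();
    for row in mask.chunks(t) {
        println!("{:?}", row); // [1, 0, 0, 0], [1, 1, 0, 0], ...
    }
}
```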