diff --git a/examples/llama/main.rs b/examples/llama/main.rs index e51beefd..fa8ec46f 100644 --- a/examples/llama/main.rs +++ b/examples/llama/main.rs @@ -289,7 +289,7 @@ impl CausalSelfAttention { let device = x.device(); // TODO: If we support bool or u8 tensors, this would be better. let mask: Vec<_> = (0..t) - .flat_map(|i| (0..t).map(move |j| u32::from(j <= i))) + .flat_map(|i| (0..t).map(move |j| u32::from(j > i))) .collect(); // Once lower_triangle is available, use the following: //let mask = Tensor::new(1u32, &device)?