Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 03:28:50 +00:00.
Fix the causal mask computation.
The previous code broadcast a constant 1 over the whole [t, t] mask, with the lower_triangle call left as a TODO, so the mask never zeroed out future positions and the attention was not actually causal. The fix materializes the lower-triangular mask on the host and uploads it with Tensor::from_slice:
@@ -290,10 +290,14 @@ impl CausalSelfAttention {
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
         let device = x.device();
         // TODO: If we support bool or u8 tensors, this would be better.
-        let mask = Tensor::new(1u32, &device)?
-            .broadcast_as(&[t, t])?
-            // TODO: .lower_triangle()?
-            .reshape(&[1, 1, t, t])?;
+        let mask: Vec<_> = (0..t)
+            .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
+            .collect();
+        // Once lower_triangle is available, use the following:
+        //let mask = Tensor::new(1u32, &device)?
+        //    .broadcast_as(&[t, t])?
+        //    .lower_triangle()?
+        let mask = Tensor::from_slice(&mask, (t, t), &device)?.reshape(&[1, 1, t, t])?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(att.rank() - 1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
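For a concrete picture of the fix, here is a minimal standalone sketch in plain Rust (no candle types; t = 4 is an arbitrary illustrative size). It runs the same flat_map chain as the new code and prints the resulting lower-triangular mask, whose row i has ones exactly at columns 0..=i, so position i can attend only to itself and earlier positions:

fn main() {
    let t = 4usize;
    // Same iterator chain as the diff: entry (i, j) is 1 when j <= i.
    let mask: Vec<u32> = (0..t)
        .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
        .collect();
    // Print the flat buffer row by row.
    for row in mask.chunks(t) {
        println!("{row:?}");
    }
    // Output:
    // [1, 0, 0, 0]
    // [1, 1, 0, 0]
    // [1, 1, 1, 0]
    // [1, 1, 1, 1]
}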
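The mask then drives the masked_fill and softmax calls in the context lines that follow it. masked_fill is a helper defined elsewhere in this file; judging from this diff it keeps a score where the mask is 1 and substitutes f32::NEG_INFINITY where it is 0. Below is a sketch of those semantics on a single attention row in plain Rust (not candle's tensor API; the scores are made up for illustration). Because exp of negative infinity is 0, softmax assigns the masked positions exactly zero weight:

fn main() {
    // One row of attention scores for t = 4, and row i = 1 of the causal
    // mask printed above.
    let scores = [0.5f32, 1.0, -0.3, 2.0];
    let mask = [1u32, 1, 0, 0];
    // masked_fill semantics: keep the score where the mask is 1, else -inf.
    let filled: Vec<f32> = scores
        .iter()
        .zip(&mask)
        .map(|(&s, &m)| if m == 1 { s } else { f32::NEG_INFINITY })
        .collect();
    // Numerically stable softmax over the row.
    let max = filled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = filled.iter().map(|&s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let weights: Vec<f32> = exps.iter().map(|&e| e / sum).collect();
    // The two masked (future) positions come out with weight 0.0.
    println!("{weights:?}");
}

In the tensor version, reshaping the mask to [1, 1, t, t] lets it broadcast over the batch and head dimensions of att, applying this same per-row behavior everywhere.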