mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Fix the causal mask computation.
This commit is contained in:
@ -290,10 +290,14 @@ impl CausalSelfAttention {
|
|||||||
let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
|
let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
|
||||||
let device = x.device();
|
let device = x.device();
|
||||||
// TODO: If we support bool or u8 tensors, this would be better.
|
// TODO: If we support bool or u8 tensors, this would be better.
|
||||||
let mask = Tensor::new(1u32, &device)?
|
let mask: Vec<_> = (0..t)
|
||||||
.broadcast_as(&[t, t])?
|
.flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
|
||||||
// TODO: .lower_triangle()?
|
.collect();
|
||||||
.reshape(&[1, 1, t, t])?;
|
// Once lower_triangle is available, use the following:
|
||||||
|
//let mask = Tensor::new(1u32, &device)?
|
||||||
|
// .broadcast_as(&[t, t])?
|
||||||
|
// .lower_triangle()?
|
||||||
|
let mask = Tensor::from_slice(&mask, (t, t), &device)?.reshape(&[1, 1, t, t])?;
|
||||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||||
let att = att.softmax(att.rank() - 1)?;
|
let att = att.softmax(att.rank() - 1)?;
|
||||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||||
|
Reference in New Issue
Block a user