Add where_cond and properly apply the causal mask.

This commit is contained in:
laurent
2023-06-25 21:08:03 +01:00
parent 25bcad290e
commit 117f014b55
8 changed files with 168 additions and 24 deletions

View File

@@ -220,10 +220,8 @@ impl Mlp {
/// Builds a tensor from `on_false` in which the positions selected by `mask` are replaced
/// with the scalar `on_true`.
///
/// The scalar is first materialised on the same device as `on_false` and broadcast to the
/// shape of `mask`, then combined with `on_false` through `Tensor::where_cond`.
///
/// NOTE(review): presumably `where_cond` picks `on_true` where `mask` is non-zero and
/// `on_false` elsewhere (like `where` / XLA's `select`) — confirm against its definition.
/// This also assumes `mask` and `on_false` have compatible shapes; verify at the call site.
///
/// # Errors
///
/// Propagates any error returned while creating the scalar tensor, broadcasting it to the
/// mask's shape, or evaluating `where_cond`.
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    // Expand the fill value to the mask's shape so it can serve as the "true" branch.
    let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}
@@ -297,7 +295,7 @@ impl CausalSelfAttention {
//let mask = Tensor::new(1u32, &device)?
// .broadcast_as(&[t, t])?
// .lower_triangle()?
let mask = Tensor::from_slice(&mask, (t, t), &device)?.reshape(&[1, 1, t, t])?;
let mask = Tensor::from_slice(&mask, (t, t), &device)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = att.softmax(att.rank() - 1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.