mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add where_cond and properly apply the causal mask.
This commit is contained in:
@@ -220,10 +220,8 @@ impl Mlp {
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let _on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
-    // TODO: add an equivalent to where (or xla's select) so that we can use the following:
-    // let m = mask.where_cond(&on_true, on_false)?;
-    let m = on_false.clone();
+    let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
 
@@ -297,7 +295,7 @@ impl CausalSelfAttention {
         //let mask = Tensor::new(1u32, &device)?
         //    .broadcast_as(&[t, t])?
         //    .lower_triangle()?
-        let mask = Tensor::from_slice(&mask, (t, t), &device)?.reshape(&[1, 1, t, t])?;
+        let mask = Tensor::from_slice(&mask, (t, t), &device)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(att.rank() - 1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
Reference in New Issue
Block a user