Add a simpler way to specify the dim index for some ops.

This commit is contained in:
laurent
2023-07-05 20:22:43 +01:00
parent b7388bbf71
commit 2c3d871b2e
7 changed files with 93 additions and 34 deletions

View File

@ -342,8 +342,8 @@ impl MultiHeadAttention {
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
qk = qk.broadcast_add(&mask)?
}
let w = qk.softmax(qk.rank() - 1)?;
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
let w = qk.softmax(candle::D::Minus1)?;
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
Ok(wv)
}
}