Mimi streaming fixes.

This commit is contained in:
Laurent
2024-09-22 14:41:17 +02:00
parent c79bf421c7
commit 3277844fd9
2 changed files with 14 additions and 1 deletions

View File

@ -216,6 +216,16 @@ impl StreamingMultiheadAttention {
let pre_ws = match mask {
None => pre_ws,
Some(mask) => {
// This is a bit cumbersome and slightly incorrect: when providing a new slice
// the kv cache will have a slice offset rather than offset + t. In the mimi
// context of an offset of 250, this would not make much difference though.
let mask_len = mask.dim(D::Minus1)?;
let pre_ws_len = pre_ws.dim(D::Minus1)?;
let mask = if pre_ws_len < mask_len {
mask.narrow(D::Minus1, mask_len - pre_ws_len, pre_ws_len)?
} else {
mask.clone()
};
let mask = mask.broadcast_left((b, self.num_heads))?;
let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
mask.where_cond(&neg_inf, &pre_ws)?