diff --git a/candle-examples/examples/mimi/main.rs b/candle-examples/examples/mimi/main.rs
index 788b3fd9..0d9948b2 100644
--- a/candle-examples/examples/mimi/main.rs
+++ b/candle-examples/examples/mimi/main.rs
@@ -126,7 +126,10 @@ fn main() -> Result<()> {
             for chunk_start in (0..seq_len).step_by(chunk_size) {
                 let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
                 let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
-                pcm_chunks.push(model.decode(&codes)?)
+                let pcm = model.decode_step(&codes.into())?;
+                if let Some(pcm) = pcm.as_option() {
+                    pcm_chunks.push(pcm.clone())
+                }
             }
             Tensor::cat(&pcm_chunks, candle::D::Minus1)?
         }
diff --git a/candle-transformers/src/models/mimi/transformer.rs b/candle-transformers/src/models/mimi/transformer.rs
index 0fa70792..6915d460 100644
--- a/candle-transformers/src/models/mimi/transformer.rs
+++ b/candle-transformers/src/models/mimi/transformer.rs
@@ -216,6 +216,16 @@ impl StreamingMultiheadAttention {
         let pre_ws = match mask {
             None => pre_ws,
             Some(mask) => {
+                // This is a bit cumbersome and slightly incorrect: when providing a new slice,
+                // the kv cache will have a slice offset rather than offset + t. In the mimi
+                // context with an offset of 250, this would not make much difference though.
+                let mask_len = mask.dim(D::Minus1)?;
+                let pre_ws_len = pre_ws.dim(D::Minus1)?;
+                let mask = if pre_ws_len < mask_len {
+                    mask.narrow(D::Minus1, mask_len - pre_ws_len, pre_ws_len)?
+                } else {
+                    mask.clone()
+                };
                 let mask = mask.broadcast_left((b, self.num_heads))?;
                 let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
                 mask.where_cond(&neg_inf, &pre_ws)?
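
Note on the first hunk: `decode_step` may consume a chunk of codes without producing any pcm yet (for instance while internal streaming buffers fill up), which is why the example now checks `as_option()` before collecting a chunk. The toy decoder below is a hypothetical, dependency-free sketch of that optional-step pattern, not candle's actual `StreamTensor` API.

```rust
// ToyStreamDecoder and its frame size are made-up stand-ins for Mimi's
// decode_step: a step may return None until enough input has accumulated,
// so callers must filter empty steps before concatenating the output.
struct ToyStreamDecoder {
    buffer: Vec<i64>, // pending codes not yet decoded
    frame: usize,     // codes needed to emit one pcm frame
}

impl ToyStreamDecoder {
    fn new(frame: usize) -> Self {
        Self { buffer: Vec::new(), frame }
    }

    // Returns None until a full frame of codes is buffered, mirroring how a
    // streaming decode step can be empty for the first few chunks.
    fn decode_step(&mut self, codes: &[i64]) -> Option<Vec<f32>> {
        self.buffer.extend_from_slice(codes);
        if self.buffer.len() < self.frame {
            return None;
        }
        let ready: Vec<i64> = self.buffer.drain(..self.frame).collect();
        // Stand-in for the real codes -> pcm computation.
        Some(ready.iter().map(|&c| c as f32 / 128.0).collect())
    }
}

fn main() {
    let mut model = ToyStreamDecoder::new(4);
    let codes: Vec<i64> = (0..10).collect();
    let mut pcm_chunks: Vec<Vec<f32>> = Vec::new();
    for chunk in codes.chunks(2) {
        // Same shape as the example's loop: only keep non-empty steps.
        if let Some(pcm) = model.decode_step(chunk) {
            pcm_chunks.push(pcm);
        }
    }
    let pcm: Vec<f32> = pcm_chunks.concat();
    println!("decoded {} samples over {} chunks", pcm.len(), pcm_chunks.len());
}
```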
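
Note on the second hunk: when the kv cache holds fewer key positions than the precomputed mask covers, the trailing columns of the mask are the ones that line up with the cached keys, so the leading columns are dropped. The helper below restates that narrowing on plain slices; it is a hypothetical sketch, whereas in the real code `mask` and `pre_ws` are candle tensors and the narrow happens on the last dimension.

```rust
// Keep the trailing pre_ws_len entries of the mask so mask and attention
// weights line up on the key axis; equivalent in spirit to
// mask.narrow(D::Minus1, mask_len - pre_ws_len, pre_ws_len).
fn narrow_mask(mask: &[bool], pre_ws_len: usize) -> Vec<bool> {
    let mask_len = mask.len();
    if pre_ws_len < mask_len {
        // Drop the leading columns that have no matching cached key.
        mask[mask_len - pre_ws_len..].to_vec()
    } else {
        mask.to_vec()
    }
}

fn main() {
    // Full mask for a context of 6 keys; true marks positions that the
    // attention code replaces with -inf.
    let mask = [true, true, false, false, false, false];
    // Only 4 keys are present in the cache so far.
    let narrowed = narrow_mask(&mask, 4);
    assert_eq!(narrowed, vec![false, false, false, false]);
    println!("{narrowed:?}");
}
```

The `else` branch covers the steady state where the cache is full and the mask and attention weights already have matching widths.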