diff --git a/candle-transformers/src/models/mimi/transformer.rs b/candle-transformers/src/models/mimi/transformer.rs
index 6915d460..8a59606e 100644
--- a/candle-transformers/src/models/mimi/transformer.rs
+++ b/candle-transformers/src/models/mimi/transformer.rs
@@ -216,16 +216,6 @@ impl StreamingMultiheadAttention {
         let pre_ws = match mask {
             None => pre_ws,
             Some(mask) => {
-                // This is a bit cumbersome and slightly incorrect: when providing a new slice
-                // the kv cache will have a slice offset rather than offset + t. In the mimi
-                // context of an offset of 250, this would not make much difference though.
-                let mask_len = mask.dim(D::Minus1)?;
-                let pre_ws_len = pre_ws.dim(D::Minus1)?;
-                let mask = if pre_ws_len < mask_len {
-                    mask.narrow(D::Minus1, mask_len - pre_ws_len, pre_ws_len)?
-                } else {
-                    mask.clone()
-                };
                 let mask = mask.broadcast_left((b, self.num_heads))?;
                 let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
                 mask.where_cond(&neg_inf, &pre_ws)?
@@ -639,18 +629,22 @@ impl StreamingTransformer {
     pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
         let (_b, t, c) = xs.dims3()?;
-        // We will extract at most "context" from the kv_cache.
-        // Note that the mask will discard the values that are before context.
         let pos = self.layers[0]
             .self_attn
             .kv_cache
             .k_cache()
-            .current_seq_len()
-            .min(self.context);
+            .current_seq_len();
         let mask = if t == 1 {
             None
         } else {
-            Some(get_mask(t, pos + t, self.context, xs.device())?)
+            let cache_out_len = if t < self.context {
+                (pos + t).min(self.context)
+            } else {
+                t
+            };
+            // TODO: this is wrong, the mask depends on the kv-cache offset because of its rotating
+            // nature.
+            Some(get_mask(t, cache_out_len, self.context, xs.device())?)
         };
         let mut xs = match self.positional_embedding {
            PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
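
For context on the new mask sizing: `forward_ca` now computes `cache_out_len`, the expected width of the attention weights, and sizes the mask up front instead of narrowing it inside `StreamingMultiheadAttention` as before. The sketch below isolates that arithmetic as a free-standing function; the standalone function and the `main` example are hypothetical and not part of the patch, and, as the TODO in the diff notes, this still ignores the offset introduced by the rotating kv-cache.

```rust
// Minimal sketch of the `cache_out_len` arithmetic from the patched `forward_ca`.
// `pos` is the number of tokens already in the kv-cache, `t` the number of new
// tokens in this call, and `context` the sliding-window size (250 in the mimi setup).
fn cache_out_len(pos: usize, t: usize, context: usize) -> usize {
    if t < context {
        // The attention weights span the cached tokens plus the new ones,
        // capped at `context` once the rotating cache is full.
        (pos + t).min(context)
    } else {
        // A chunk at least as long as the context only attends within itself.
        t
    }
}

fn main() {
    // Streaming with context = 250: the mask width grows with the cache until it saturates.
    assert_eq!(cache_out_len(0, 4, 250), 4);
    assert_eq!(cache_out_len(248, 4, 250), 250);
    // A one-shot chunk longer than the context keeps its own length.
    assert_eq!(cache_out_len(0, 300, 250), 300);
}
```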