diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs index 882f4cf8..b393d599 100644 --- a/candle-transformers/src/models/quantized_mixformer.rs +++ b/candle-transformers/src/models/quantized_mixformer.rs @@ -199,7 +199,10 @@ impl MHA { Some((prev_k, _)) => prev_k.dim(1)?, }; // In the python implementation, a single tensor is returned with the third axis of size 3. - let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?; + // let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?; + let q = qkv.i((.., .., 0))?; + let k = qkv.i((.., .., 1))?; + let v = qkv.i((.., .., 2))?; let (k, v) = match &self.kv_cache { None => (k, v), Some((prev_k, prev_v)) => {