From b45c710dbf61445751ae56052131ccd40a25b6b8 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 19 Apr 2024 21:49:55 +0200 Subject: [PATCH] Fix for gemma MQA. (#2091) --- candle-transformers/src/models/gemma.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs index 58b5f1e1..3bde88b4 100644 --- a/candle-transformers/src/models/gemma.rs +++ b/candle-transformers/src/models/gemma.rs @@ -227,8 +227,9 @@ impl Attention { }; self.kv_cache = Some((key_states.clone(), value_states.clone())); - let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?; - let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?; + let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; + let value_states = + crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; let attn_output = { let scale = 1f64 / f64::sqrt(self.head_dim as f64);