diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs index 58b5f1e1..3bde88b4 100644 --- a/candle-transformers/src/models/gemma.rs +++ b/candle-transformers/src/models/gemma.rs @@ -227,8 +227,9 @@ impl Attention { }; self.kv_cache = Some((key_states.clone(), value_states.clone())); - let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?; - let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?; + let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; + let value_states = + crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; let attn_output = { let scale = 1f64 / f64::sqrt(self.head_dim as f64);