Fix for gemma MQA. (#2091)

This commit is contained in:
Laurent Mazare
2024-04-19 21:49:55 +02:00
committed by GitHub
parent 9c532aef47
commit b45c710dbf

View File

@ -227,8 +227,9 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);