GQA support in the quantized model. (#555)

* GQA support in the quantized model.

* Fix the reshaping.

* Fix the main llama model.

* Infer the proper gqa from the model kind.
Laurent Mazare
2023-08-22 19:41:10 +01:00
committed by GitHub
parent 07067b01dc
commit f9ecc84477
2 changed files with 32 additions and 6 deletions


```diff
@@ -291,7 +291,7 @@ impl CausalSelfAttention {
         let x = x
             .unsqueeze(2)?
             .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-            .reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
+            .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
         Ok(x)
     }
 }
```
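The fix above is the standard GQA "repeat-KV" step: each of the `n_kv_head` key/value heads is repeated `n_rep` times along a new axis, then that axis is flattened into the head axis so the result has `n_kv_head * n_rep` heads matching the query heads. The buggy line reshaped to the same 5-D shape instead of merging the two axes. A hypothetical pure-Python sketch of the same shape manipulation (nested lists standing in for tensors; this is not the candle API):

```python
# Sketch of the GQA repeat-KV step fixed in this commit.
# Input shape:  (b_sz, n_kv_head, seq_len, head_dim)
#   unsqueeze -> (b_sz, n_kv_head, 1, seq_len, head_dim)
#   expand    -> (b_sz, n_kv_head, n_rep, seq_len, head_dim)
#   reshape   -> (b_sz, n_kv_head * n_rep, seq_len, head_dim)
# Flattening (n_kv_head, n_rep) in that order repeats each KV head
# n_rep consecutive times, which is what the queries expect.

def repeat_kv(x, n_rep):
    # x: nested lists of shape (b_sz, n_kv_head, seq_len, head_dim)
    return [
        [head for head in batch for _ in range(n_rep)]
        for batch in x
    ]
```

With `n_kv_head = 2` and `n_rep = 2` the output has 4 heads per batch, ordered `[kv0, kv0, kv1, kv1]`, i.e. each KV head repeated consecutively rather than the whole block tiled.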