GQA support in the quantized model. (#555)

* GQA support in the quantized model.

* Fix the reshaping.

* Fix the main llama model.

* Infer the proper gqa from the model kind.
This commit is contained in:
Laurent Mazare
2023-08-22 19:41:10 +01:00
committed by GitHub
parent 07067b01dc
commit f9ecc84477
2 changed files with 32 additions and 6 deletions

View File

@ -291,7 +291,7 @@ impl CausalSelfAttention {
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
Ok(x)
}
}

View File

@ -68,6 +68,7 @@ struct LayerWeights {
feed_forward_w3: QMatMul,
ffn_norm: RmsNorm,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
cos: Tensor,
sin: Tensor,
@ -125,10 +126,10 @@ impl LayerWeights {
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
@ -144,7 +145,9 @@ impl LayerWeights {
};
self.kv_cache = Some((k.clone(), v.clone()));
// If we start supporting MQA, we need to repeat the k and v tensors here.
// Support for MQA, useful for 70B models.
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = mask.broadcast_as(att.shape())?;
@ -156,6 +159,20 @@ impl LayerWeights {
let y = self.attention_wo.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_kv_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
Ok(x)
}
}
}
struct ModelWeights {
@ -179,7 +196,7 @@ impl WeightMap {
}
impl ModelWeights {
fn new(mut ct: Content) -> Result<Self> {
fn new(mut ct: Content, gqa: usize) -> Result<Self> {
let cpu = &Device::Cpu;
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
@ -226,6 +243,7 @@ impl ModelWeights {
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
ffn_norm: RmsNorm::new(ffn_norm)?,
n_head: ct.hparams.n_head as usize,
n_kv_head: ct.hparams.n_head as usize / gqa,
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
cos: cos.clone(),
sin: sin.clone(),
@ -347,6 +365,10 @@ struct Args {
/// The model size to use.
#[arg(long, default_value = "7b")]
which: Which,
/// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
#[arg(long)]
gqa: Option<usize>,
}
impl Args {
@ -468,7 +490,11 @@ fn main() -> anyhow::Result<()> {
start.elapsed().as_secs_f32(),
);
println!("params: {:?}", model.hparams);
let mut model = ModelWeights::new(model)?;
let default_gqa = match args.which {
Which::L7b | Which::L13b => 1,
Which::L70b => 8,
};
let mut model = ModelWeights::new(model, args.gqa.unwrap_or(default_gqa))?;
println!("model built");
let tokenizer = args.tokenizer()?;