GQA support in the quantized model. (#555)

* GQA support in the quantized model.

* Fix the reshaping.

* Fix the main llama model.

* Infer the proper gqa from the model kind.
Author: Laurent Mazare
Date: 2023-08-22 19:41:10 +01:00
Committed by: GitHub
Parent: 07067b01dc
Commit: f9ecc84477
2 changed files with 32 additions and 6 deletions
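
For context, grouped-query attention (GQA) shares each key/value head across several query heads, so the k and v projections produce fewer heads than the q projection and have to be repeated before the attention matmul. A minimal sketch of the bookkeeping this commit introduces, using the 70B numbers purely as an illustration:

    // Sketch of the GQA head bookkeeping added by this commit.
    // The concrete values (64 query heads, gqa = 8) are illustrative assumptions.
    fn main() {
        let n_head = 64; // query heads from the model hparams
        let gqa = 8; // query heads per key/value head (--gqa 8 for the 70B model)
        let n_kv_head = n_head / gqa; // = 8 key/value heads
        let n_rep = n_head / n_kv_head; // each k/v head gets repeated 8 times
        println!("n_kv_head = {n_kv_head}, n_rep = {n_rep}");
    }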

@@ -291,7 +291,7 @@ impl CausalSelfAttention {
             let x = x
                 .unsqueeze(2)?
                 .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
+                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
             Ok(x)
         }
     }
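
The fix above is in the reshape target: the tensor is first expanded to (b_sz, n_kv_head, n_rep, seq_len, head_dim), and the repeat axis then has to be folded into the head axis; otherwise k and v keep only n_kv_head heads and no longer line up with the n_head query heads in the attention matmul. A small shape check of the corrected pattern, assuming candle_core as the only dependency and toy dimensions:

    // Shape check of the corrected repeat: fold the n_rep axis into the head axis.
    // Dimensions are toy values; candle_core is assumed to be a dependency.
    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let (b_sz, n_kv_head, n_rep, seq_len, head_dim) = (1, 8, 8, 4, 16);
        let kv = Tensor::zeros((b_sz, n_kv_head, seq_len, head_dim), DType::F32, &Device::Cpu)?;
        let kv = kv
            .unsqueeze(2)?
            .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
            .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
        assert_eq!(kv.dims(), &[1, 64, 4, 16]); // 64 heads, matching the query tensor
        Ok(())
    }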

@@ -68,6 +68,7 @@ struct LayerWeights {
     feed_forward_w3: QMatMul,
     ffn_norm: RmsNorm,
     n_head: usize,
+    n_kv_head: usize,
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
@@ -125,10 +126,10 @@ impl LayerWeights {
             .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
             .transpose(1, 2)?;
         let k = k
-            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
             .transpose(1, 2)?;
         let v = v
-            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
             .transpose(1, 2)?;

         let q = self.apply_rotary_emb(&q, index_pos)?;
@@ -144,7 +145,9 @@ impl LayerWeights {
         };
         self.kv_cache = Some((k.clone(), v.clone()));

-        // If we start supporting MQA, we need to repeat the k and v tensors here.
+        // Support for MQA, useful for 70B models.
+        let k = self.repeat_kv(k)?;
+        let v = self.repeat_kv(v)?;

         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let mask = mask.broadcast_as(att.shape())?;
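
Note that repeat_kv runs after the cache update, so kv_cache keeps the compact n_kv_head layout and only the tensors fed into the attention matmul are expanded. A rough estimate of why that matters for a 70B-sized model; the layer count, head_dim, context length and f32 cache dtype are illustrative assumptions, not values read from the model file:

    // Rough cache-size estimate: caching k and v with 8 kv heads instead of 64.
    // n_layer = 80, head_dim = 128, seq_len = 2048 and 4-byte elements are
    // illustrative assumptions for a 70B-sized model.
    fn main() {
        let (n_layer, head_dim, seq_len, bytes) = (80usize, 128, 2048, 4);
        let per_head = head_dim * seq_len * bytes; // bytes per head for one of k or v
        let mha = 2 * n_layer * 64 * per_head; // if all 64 query heads were cached
        let gqa = 2 * n_layer * 8 * per_head; // with only the 8 kv heads cached
        println!("MHA-style cache: {} MiB", mha >> 20);
        println!("GQA cache:       {} MiB", gqa >> 20);
    }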
@@ -156,6 +159,20 @@ impl LayerWeights {
         let y = self.attention_wo.forward(&y)?;
         Ok(y)
     }
+
+    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
+        let n_rep = self.n_head / self.n_kv_head;
+        if n_rep == 1 {
+            Ok(x)
+        } else {
+            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
+            let x = x
+                .unsqueeze(2)?
+                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
+                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
+            Ok(x)
+        }
+    }
 }

 struct ModelWeights {
@@ -179,7 +196,7 @@ impl WeightMap {
 }

 impl ModelWeights {
-    fn new(mut ct: Content) -> Result<Self> {
+    fn new(mut ct: Content, gqa: usize) -> Result<Self> {
         let cpu = &Device::Cpu;
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
@@ -226,6 +243,7 @@ impl ModelWeights {
             feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
             ffn_norm: RmsNorm::new(ffn_norm)?,
             n_head: ct.hparams.n_head as usize,
+            n_kv_head: ct.hparams.n_head as usize / gqa,
             head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
             cos: cos.clone(),
             sin: sin.clone(),
@@ -347,6 +365,10 @@ struct Args {
     /// The model size to use.
     #[arg(long, default_value = "7b")]
     which: Which,
+
+    /// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
+    #[arg(long)]
+    gqa: Option<usize>,
 }

 impl Args {
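
For readers less familiar with clap's derive API: the new flag is a plain optional integer that stays None unless --gqa is passed on the command line. A standalone sketch of just that field (not the example's full Args struct), assuming clap with the derive feature:

    // Minimal clap sketch of the new flag in isolation.
    use clap::Parser;

    #[derive(Parser, Debug)]
    struct Args {
        /// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
        #[arg(long)]
        gqa: Option<usize>,
    }

    fn main() {
        let args = Args::parse();
        println!("gqa override: {:?}", args.gqa); // None unless --gqa N was given
    }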
@@ -468,7 +490,11 @@ fn main() -> anyhow::Result<()> {
         start.elapsed().as_secs_f32(),
     );
     println!("params: {:?}", model.hparams);
-    let mut model = ModelWeights::new(model)?;
+    let default_gqa = match args.which {
+        Which::L7b | Which::L13b => 1,
+        Which::L70b => 8,
+    };
+    let mut model = ModelWeights::new(model, args.gqa.unwrap_or(default_gqa))?;
     println!("model built");

     let tokenizer = args.tokenizer()?;
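
One caveat: n_kv_head is computed as n_head / gqa with integer division, so a gqa value that does not evenly divide the head count would silently produce a wrong head layout rather than an error. A defensive check along these lines is not part of the commit, just a hedged suggestion reusing the anyhow error type that main already returns:

    // Hypothetical guard (not in the commit): require gqa to be a non-zero
    // divisor of n_head before deriving n_kv_head.
    fn checked_n_kv_head(n_head: usize, gqa: usize) -> anyhow::Result<usize> {
        anyhow::ensure!(
            gqa != 0 && n_head % gqa == 0,
            "--gqa {gqa} must be a non-zero divisor of n_head {n_head}"
        );
        Ok(n_head / gqa)
    }

    fn main() -> anyhow::Result<()> {
        assert_eq!(checked_n_kv_head(64, 8)?, 8);
        assert!(checked_n_kv_head(64, 7).is_err());
        Ok(())
    }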