mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
GQA support in the quantized model. (#555)
* GQA support in the quantized model. * Fix the reshaping. * Fix the main llama model. * Infer the proper gqa from the model kind.
This commit is contained in:
@ -291,7 +291,7 @@ impl CausalSelfAttention {
|
||||
let x = x
|
||||
.unsqueeze(2)?
|
||||
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
|
||||
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
@ -68,6 +68,7 @@ struct LayerWeights {
|
||||
feed_forward_w3: QMatMul,
|
||||
ffn_norm: RmsNorm,
|
||||
n_head: usize,
|
||||
n_kv_head: usize,
|
||||
head_dim: usize,
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
@ -125,10 +126,10 @@ impl LayerWeights {
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let v = v
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
@ -144,7 +145,9 @@ impl LayerWeights {
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
|
||||
// If we start supporting MQA, we need to repeat the k and v tensors here.
|
||||
// Support for MQA, useful for 70B models.
|
||||
let k = self.repeat_kv(k)?;
|
||||
let v = self.repeat_kv(v)?;
|
||||
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let mask = mask.broadcast_as(att.shape())?;
|
||||
@ -156,6 +159,20 @@ impl LayerWeights {
|
||||
let y = self.attention_wo.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.n_head / self.n_kv_head;
|
||||
if n_rep == 1 {
|
||||
Ok(x)
|
||||
} else {
|
||||
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
|
||||
let x = x
|
||||
.unsqueeze(2)?
|
||||
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ModelWeights {
|
||||
@ -179,7 +196,7 @@ impl WeightMap {
|
||||
}
|
||||
|
||||
impl ModelWeights {
|
||||
fn new(mut ct: Content) -> Result<Self> {
|
||||
fn new(mut ct: Content, gqa: usize) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
||||
|
||||
@ -226,6 +243,7 @@ impl ModelWeights {
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
|
||||
ffn_norm: RmsNorm::new(ffn_norm)?,
|
||||
n_head: ct.hparams.n_head as usize,
|
||||
n_kv_head: ct.hparams.n_head as usize / gqa,
|
||||
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
@ -347,6 +365,10 @@ struct Args {
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "7b")]
|
||||
which: Which,
|
||||
|
||||
/// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
|
||||
#[arg(long)]
|
||||
gqa: Option<usize>,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
@ -468,7 +490,11 @@ fn main() -> anyhow::Result<()> {
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
println!("params: {:?}", model.hparams);
|
||||
let mut model = ModelWeights::new(model)?;
|
||||
let default_gqa = match args.which {
|
||||
Which::L7b | Which::L13b => 1,
|
||||
Which::L70b => 8,
|
||||
};
|
||||
let mut model = ModelWeights::new(model, args.gqa.unwrap_or(default_gqa))?;
|
||||
println!("model built");
|
||||
|
||||
let tokenizer = args.tokenizer()?;
|
||||
|
Reference in New Issue
Block a user