mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 18:28:24 +00:00)
GQA support in the quantized model. (#555)
* GQA support in the quantized model.
* Fix the reshaping.
* Fix the main llama model.
* Infer the proper gqa from the model kind.
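For context on the repeat_kv helper introduced below: with grouped-query attention the model only stores n_kv_head = n_head / gqa key/value heads, and each of them has to be repeated n_rep = gqa times so that the per-head attention matmul still sees n_head heads. A minimal standalone sketch of that expansion, assuming a candle_core dependency and made-up 70B-style sizes (n_head = 64, gqa = 8); it uses the same unsqueeze/expand/reshape chain as the patch:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Hypothetical 70B-style sizes: 64 query heads sharing 8 key/value heads.
    let (b_sz, n_head, n_kv_head, seq_len, head_dim) = (1, 64, 8, 5, 128);
    let n_rep = n_head / n_kv_head; // the gqa factor, 8 here

    // Stand-in for the cached keys, shaped (batch, n_kv_head, seq, head_dim).
    let k = Tensor::zeros((b_sz, n_kv_head, seq_len, head_dim), DType::F32, &dev)?;

    // Same expansion as repeat_kv in this commit: insert a repeat axis,
    // broadcast it to n_rep copies, then fold it into the head axis.
    let k = k
        .unsqueeze(2)?
        .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
        .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;

    // k now lines up with a (b_sz, n_head, seq_len, head_dim) query tensor
    // for the q.matmul(&k.t()?) call in the attention layer.
    assert_eq!(k.dims4()?, (b_sz, n_head, seq_len, head_dim));
    Ok(())
}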
@@ -291,7 +291,7 @@ impl CausalSelfAttention {
             let x = x
                 .unsqueeze(2)?
                 .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
+                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
             Ok(x)
         }
     }
@@ -68,6 +68,7 @@ struct LayerWeights {
     feed_forward_w3: QMatMul,
     ffn_norm: RmsNorm,
     n_head: usize,
+    n_kv_head: usize,
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
@@ -125,10 +126,10 @@ impl LayerWeights {
             .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
             .transpose(1, 2)?;
         let k = k
-            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
             .transpose(1, 2)?;
         let v = v
-            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
             .transpose(1, 2)?;
 
         let q = self.apply_rotary_emb(&q, index_pos)?;
@@ -144,7 +145,9 @@ impl LayerWeights {
         };
         self.kv_cache = Some((k.clone(), v.clone()));
 
-        // If we start supporting MQA, we need to repeat the k and v tensors here.
+        // Support for MQA, useful for 70B models.
+        let k = self.repeat_kv(k)?;
+        let v = self.repeat_kv(v)?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let mask = mask.broadcast_as(att.shape())?;
@@ -156,6 +159,20 @@ impl LayerWeights {
         let y = self.attention_wo.forward(&y)?;
         Ok(y)
     }
+
+    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
+        let n_rep = self.n_head / self.n_kv_head;
+        if n_rep == 1 {
+            Ok(x)
+        } else {
+            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
+            let x = x
+                .unsqueeze(2)?
+                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
+                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
+            Ok(x)
+        }
+    }
 }
 
 struct ModelWeights {
@@ -179,7 +196,7 @@ impl WeightMap {
 }
 
 impl ModelWeights {
-    fn new(mut ct: Content) -> Result<Self> {
+    fn new(mut ct: Content, gqa: usize) -> Result<Self> {
         let cpu = &Device::Cpu;
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
 
@@ -226,6 +243,7 @@ impl ModelWeights {
                 feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
                 ffn_norm: RmsNorm::new(ffn_norm)?,
                 n_head: ct.hparams.n_head as usize,
+                n_kv_head: ct.hparams.n_head as usize / gqa,
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                 cos: cos.clone(),
                 sin: sin.clone(),
@@ -347,6 +365,10 @@ struct Args {
     /// The model size to use.
     #[arg(long, default_value = "7b")]
     which: Which,
+
+    /// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
+    #[arg(long)]
+    gqa: Option<usize>,
 }
 
 impl Args {
@@ -468,7 +490,11 @@ fn main() -> anyhow::Result<()> {
         start.elapsed().as_secs_f32(),
     );
     println!("params: {:?}", model.hparams);
-    let mut model = ModelWeights::new(model)?;
+    let default_gqa = match args.which {
+        Which::L7b | Which::L13b => 1,
+        Which::L70b => 8,
+    };
+    let mut model = ModelWeights::new(model, args.gqa.unwrap_or(default_gqa))?;
     println!("model built");
 
     let tokenizer = args.tokenizer()?;
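With the diff above, --gqa only needs to be passed to override the per-model default inferred from --which. Assuming the file being patched is candle's quantized llama example (the example name and the "70b" value are guesses; only the --which and --gqa flags come from the diff), running a 70B model could look like:

cargo run --release --example quantized -- --which 70b --gqa 8

where --gqa 8 is now redundant, since Which::L70b already maps to a default of 8.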