mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
implemented quantized-gemma3 (#2902)
* implemented quantized-gemma, inference not working * Fixed a few modeling bugs: outputing the correct tokens for a few iterations then garbage * lint * clippy * quantized-gemma3 example working * added readme * clippy
This commit is contained in:
@ -79,6 +79,7 @@ pub mod phi3;
|
||||
pub mod pixtral;
|
||||
pub mod quantized_blip;
|
||||
pub mod quantized_blip_text;
|
||||
pub mod quantized_gemma3;
|
||||
pub mod quantized_llama;
|
||||
pub mod quantized_llama2_c;
|
||||
pub mod quantized_metavoice;
|
||||
|
418
candle-transformers/src/models/quantized_gemma3.rs
Normal file
418
candle-transformers/src/models/quantized_gemma3.rs
Normal file
@ -0,0 +1,418 @@
|
||||
//! Gemma 3 model implementation with quantization support.
|
||||
//!
|
||||
//! Gemma 3 is a family of multimodal language models developed by Google.
|
||||
//! This implementation provides quantization for reduced memory usage and faster inference.
|
||||
//!
|
||||
//! Key characteristics:
|
||||
//! - Group-Query Attention (GQA) with specialized key-value heads
|
||||
//! - RMSNorm for layer normalization
|
||||
//! - Specialized attention patterns with separate normalization for Q/K/V
|
||||
//! - Feed-forward network with SwiGLU activation
|
||||
//! - Support for 2/3/4/8-bit quantization
|
||||
//!
|
||||
//! References:
|
||||
//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
|
||||
//!
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::quantized_nn::RmsNorm;
|
||||
use candle::quantized::gguf_file;
|
||||
use candle::quantized::QTensor;
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
use candle_nn::{Embedding, Module};
|
||||
|
||||
pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct QMatMul {
|
||||
inner: candle::quantized::QMatMul,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl QMatMul {
|
||||
fn from_qtensor(qtensor: QTensor) -> Result<Self> {
|
||||
let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
feed_forward_gate: QMatMul, // ffn_gate in GGUF
|
||||
feed_forward_up: QMatMul, // ffn_up in GGUF
|
||||
feed_forward_down: QMatMul, // ffn_down in GGUF
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let gate = self.feed_forward_gate.forward(xs)?;
|
||||
let up = self.feed_forward_up.forward(xs)?;
|
||||
let silu = candle_nn::ops::silu(&gate)?;
|
||||
let gated = (silu * up)?;
|
||||
self.feed_forward_down.forward(&gated)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LayerWeights {
|
||||
// Attention components
|
||||
attention_wq: QMatMul,
|
||||
attention_wk: QMatMul,
|
||||
attention_wv: QMatMul,
|
||||
attention_wo: QMatMul,
|
||||
|
||||
// Specialized normalization for Q and K
|
||||
attention_q_norm: RmsNorm,
|
||||
attention_k_norm: RmsNorm,
|
||||
|
||||
// Layer normalization
|
||||
attention_norm: RmsNorm, // Applied before attention
|
||||
post_attention_norm: RmsNorm, // Applied after attention
|
||||
ffn_norm: RmsNorm, // Applied before feedforward
|
||||
post_ffn_norm: RmsNorm, // Applied after feedforward
|
||||
|
||||
// Feed-forward network
|
||||
mlp: Mlp,
|
||||
|
||||
// Attention parameters
|
||||
n_head: usize, // Number of query heads
|
||||
n_kv_head: usize, // Number of key-value heads
|
||||
head_dim: usize, // Dimension of each head
|
||||
q_dim: usize, // Total dimension for queries
|
||||
|
||||
// Rotary embedding
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
neg_inf: Tensor,
|
||||
|
||||
// Cache
|
||||
pub kv_cache: Option<(Tensor, Tensor)>,
|
||||
|
||||
// Tracing
|
||||
span_attn: tracing::Span,
|
||||
span_mlp: tracing::Span,
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
impl LayerWeights {
|
||||
fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
index_pos: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.sin.narrow(0, index_pos, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
|
||||
fn forward_attn(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
mask: Option<&Tensor>,
|
||||
index_pos: usize,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span_attn.enter();
|
||||
let (b_sz, seq_len, _) = x.dims3()?;
|
||||
|
||||
let q = self.attention_wq.forward(x)?;
|
||||
let k = self.attention_wk.forward(x)?;
|
||||
let v = self.attention_wv.forward(x)?;
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let v = v
|
||||
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let q = self.attention_q_norm.forward(&q.contiguous()?)?;
|
||||
let k = self.attention_k_norm.forward(&k.contiguous()?)?;
|
||||
|
||||
let (q, k) = self.apply_rotary_emb_qkv(&q, &k, index_pos)?;
|
||||
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k, v),
|
||||
Some((k_cache, v_cache)) => {
|
||||
if index_pos == 0 {
|
||||
(k, v)
|
||||
} else {
|
||||
let k = Tensor::cat(&[k_cache, &k], 2)?; // concat on seq dim
|
||||
let v = Tensor::cat(&[v_cache, &v], 2)?;
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone())); // update cache
|
||||
|
||||
// Repeat KV for GQA
|
||||
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
|
||||
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
|
||||
|
||||
// Scaled Dot-Product Attention
|
||||
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||
let mut attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||
|
||||
if let Some(mask) = mask {
|
||||
let mask = mask.broadcast_as(attn_weights.shape())?;
|
||||
attn_weights = masked_fill(&attn_weights, &mask, &self.neg_inf)?;
|
||||
}
|
||||
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
let attn_output = attn_weights.matmul(&v)?;
|
||||
|
||||
let attn_output = attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, seq_len, self.q_dim))?;
|
||||
|
||||
self.attention_wo.forward(&attn_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelWeights {
|
||||
tok_embeddings: Embedding,
|
||||
embedding_length: usize,
|
||||
pub layers: Vec<LayerWeights>,
|
||||
norm: RmsNorm,
|
||||
output: QMatMul,
|
||||
masks: HashMap<usize, Tensor>,
|
||||
span: tracing::Span,
|
||||
span_output: tracing::Span,
|
||||
}
|
||||
|
||||
fn precomput_freqs_cis(
|
||||
head_dim: usize,
|
||||
freq_base: f32,
|
||||
device: &Device,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
};
|
||||
|
||||
let head_count = md_get("gemma3.attention.head_count")?.to_u32()? as usize;
|
||||
let head_count_kv = md_get("gemma3.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("gemma3.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("gemma3.embedding_length")?.to_u32()? as usize;
|
||||
let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
|
||||
let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
|
||||
let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||
|
||||
let rope_freq_base = md_get("gemma3.rope.freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(1000000f32);
|
||||
|
||||
// Compute the dimensions for queries, keys, and values
|
||||
// These are the total dimensions when projected across all heads
|
||||
let q_dim = head_count * key_length;
|
||||
|
||||
// Precompute rotary embeddings
|
||||
let (cos, sin) = precomput_freqs_cis(key_length, rope_freq_base, device)?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||
|
||||
// Load token embeddings and output projection
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(reader, "output_norm.weight", device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
let output = match ct.tensor(reader, "output.weight", device) {
|
||||
Ok(tensor) => tensor,
|
||||
Err(_) => ct.tensor(reader, "token_embd.weight", device)?, // Use tied weights if output.weight doesn't exist
|
||||
};
|
||||
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
|
||||
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
|
||||
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
|
||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
|
||||
let attention_wo =
|
||||
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
|
||||
|
||||
let attention_q_norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(reader, &format!("{prefix}.attn_q_norm.weight"), device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
|
||||
let attention_k_norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(reader, &format!("{prefix}.attn_k_norm.weight"), device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
|
||||
let attention_norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
|
||||
let post_attention_norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(
|
||||
reader,
|
||||
&format!("{prefix}.post_attention_norm.weight"),
|
||||
device,
|
||||
)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
|
||||
let ffn_norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
|
||||
let post_ffn_norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(reader, &format!("{prefix}.post_ffw_norm.weight"), device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
|
||||
let feed_forward_gate =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
|
||||
let feed_forward_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||
let feed_forward_down =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||
|
||||
let mlp = Mlp {
|
||||
feed_forward_gate: QMatMul::from_qtensor(feed_forward_gate)?,
|
||||
feed_forward_up: QMatMul::from_qtensor(feed_forward_up)?,
|
||||
feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
|
||||
};
|
||||
|
||||
// Tracing spans
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||
|
||||
layers.push(LayerWeights {
|
||||
attention_wq: QMatMul::from_qtensor(attention_wq)?,
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk)?,
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||
attention_q_norm,
|
||||
attention_k_norm,
|
||||
attention_norm,
|
||||
post_attention_norm,
|
||||
ffn_norm,
|
||||
post_ffn_norm,
|
||||
mlp,
|
||||
n_head: head_count,
|
||||
n_kv_head: head_count_kv,
|
||||
head_dim: key_length,
|
||||
q_dim,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
neg_inf: neg_inf.clone(),
|
||||
kv_cache: None,
|
||||
span_attn,
|
||||
span_mlp,
|
||||
})
|
||||
}
|
||||
|
||||
let span = tracing::span!(tracing::Level::TRACE, "model");
|
||||
let span_output = tracing::span!(tracing::Level::TRACE, "output");
|
||||
|
||||
Ok(Self {
|
||||
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
|
||||
embedding_length,
|
||||
layers,
|
||||
norm,
|
||||
output: QMatMul::from_qtensor(output)?,
|
||||
masks: HashMap::new(),
|
||||
span,
|
||||
span_output,
|
||||
})
|
||||
}
|
||||
|
||||
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
|
||||
if let Some(mask) = self.masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
} else {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), device)?;
|
||||
self.masks.insert(t, mask.clone());
|
||||
Ok(mask)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len) = x.dims2()?;
|
||||
|
||||
let mask = if seq_len == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(self.mask(seq_len, x.device())?)
|
||||
};
|
||||
let _enter = self.span.enter();
|
||||
|
||||
let mut layer_in = self.tok_embeddings.forward(x)?;
|
||||
layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;
|
||||
|
||||
for layer in self.layers.iter_mut() {
|
||||
// Attention block
|
||||
let residual = &layer_in;
|
||||
let x = layer.attention_norm.forward(&layer_in)?;
|
||||
let x = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
|
||||
let x = layer.post_attention_norm.forward(&x)?;
|
||||
let x = (x + residual)?;
|
||||
|
||||
// Feed-forward block
|
||||
let _enter = layer.span_mlp.enter();
|
||||
let residual = &x;
|
||||
let x = layer.ffn_norm.forward(&x)?;
|
||||
let x = layer.mlp.forward(&x)?;
|
||||
let x = layer.post_ffn_norm.forward(&x)?;
|
||||
let x = (x + residual)?;
|
||||
drop(_enter);
|
||||
|
||||
layer_in = x;
|
||||
}
|
||||
|
||||
let _enter = self.span_output.enter();
|
||||
|
||||
let x = layer_in.i((.., seq_len - 1, ..))?;
|
||||
let x = self.norm.forward(&x)?;
|
||||
let output = self.output.forward(&x)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user