Fixed Quantized Gemma3 Model and example (#2918)

* removed scale factor from computation and made quantized gemma3 work similarly to non-quantized gemma3

* created default consts, replaced is_sliding with Option holding a window_size
This commit is contained in:
Kyle Birnbaum
2025-04-24 20:47:48 -07:00
committed by GitHub
parent 6ff0a6999c
commit 3aeb9575c7

View File

@ -14,15 +14,18 @@
//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
//!
use std::collections::HashMap;
use crate::quantized_nn::RmsNorm;
use candle::quantized::gguf_file;
use candle::quantized::QTensor;
use candle::D;
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6;
pub const DEFAULT_ROPE_FREQUENCY: f32 = 1_000_000.;
pub const DEFAULT_ROPE_FREQUENCY_SLIDING: f32 = 10_000.;
pub const DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR: f32 = 1.;
#[derive(Debug, Clone)]
struct QMatMul {
@ -61,7 +64,44 @@ impl Module for Mlp {
}
#[derive(Debug, Clone)]
pub struct LayerWeights {
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(head_dim: usize, rope_frequency: f32, device: &Device) -> Result<Self> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / rope_frequency.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok(Self { sin, cos })
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
index_pos: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, index_pos, seq_len)?;
let sin = self.sin.narrow(0, index_pos, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
struct LayerWeights {
// Attention components
attention_wq: QMatMul,
attention_wk: QMatMul,
@ -87,38 +127,54 @@ pub struct LayerWeights {
head_dim: usize, // Dimension of each head
q_dim: usize, // Total dimension for queries
// Rotary embedding
cos: Tensor,
sin: Tensor,
sliding_window_size: Option<usize>,
rotary_embedding: RotaryEmbedding,
neg_inf: Tensor,
// Cache
pub kv_cache: Option<(Tensor, Tensor)>,
kv_cache: Option<(Tensor, Tensor)>,
// Tracing
span_attn: tracing::Span,
span_mlp: tracing::Span,
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
let shape = mask.shape();
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
Ok(m)
}
impl LayerWeights {
fn apply_rotary_emb_qkv(
fn mask(
&self,
q: &Tensor,
k: &Tensor,
b_sz: usize,
seq_len: usize,
index_pos: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, index_pos, seq_len)?;
let sin = self.sin.narrow(0, index_pos, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
dtype: DType,
device: &Device,
) -> Result<Tensor> {
let mask: Vec<_> = if let Some(sliding_window_size) = self.sliding_window_size {
(0..seq_len)
.flat_map(|i| {
(0..seq_len).map(move |j| {
if i < j || j + sliding_window_size < i {
0u32
} else {
1u32
}
})
})
.collect()
} else {
(0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| if i < j { 0u32 } else { 1u32 }))
.collect()
};
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
let mask = if index_pos > 0 {
let mask0 = Tensor::zeros((seq_len, index_pos), DType::F32, device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_sz, 1, seq_len, seq_len + index_pos))?
.to_dtype(dtype)
}
fn forward_attn(
@ -147,7 +203,9 @@ impl LayerWeights {
let q = self.attention_q_norm.forward(&q.contiguous()?)?;
let k = self.attention_k_norm.forward(&k.contiguous()?)?;
let (q, k) = self.apply_rotary_emb_qkv(&q, &k, index_pos)?;
let (q, k) = self
.rotary_embedding
.apply_rotary_emb_qkv(&q, &k, index_pos)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
@ -173,7 +231,8 @@ impl LayerWeights {
if let Some(mask) = mask {
let mask = mask.broadcast_as(attn_weights.shape())?;
attn_weights = masked_fill(&attn_weights, &mask, &self.neg_inf)?;
let neg_inf = self.neg_inf.broadcast_as(attn_weights.dims())?;
attn_weights = mask.eq(0u32)?.where_cond(&neg_inf, &attn_weights)?;
}
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
@ -191,33 +250,13 @@ impl LayerWeights {
pub struct ModelWeights {
tok_embeddings: Embedding,
embedding_length: usize,
pub layers: Vec<LayerWeights>,
layers: Vec<LayerWeights>,
norm: RmsNorm,
output: QMatMul,
masks: HashMap<usize, Tensor>,
span: tracing::Span,
span_output: tracing::Span,
}
fn precomput_freqs_cis(
head_dim: usize,
freq_base: f32,
device: &Device,
) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok((cos, sin))
}
impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
@ -236,25 +275,29 @@ impl ModelWeights {
let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let sliding_window_size = md_get("gemma3.attention.sliding_window")?.to_u32()? as usize;
let sliding_window_type = md_get("gemma3.attention.sliding_window_type")
.and_then(|m| Ok(m.to_u32()? as usize))
.unwrap_or(DEFAULT_SLIDING_WINDOW_TYPE);
let rope_freq_base = md_get("gemma3.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(1000000f32);
.unwrap_or(DEFAULT_ROPE_FREQUENCY);
let rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
let rope_freq_base_sliding = md_get("gemma3.rope.local_freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(8f32);
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SLIDING);
// Unused in Llama.cpp so we aren't using it here.
let _rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
.and_then(|m| m.to_f32())
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR);
// Compute the dimensions for queries, keys, and values
// These are the total dimensions when projected across all heads
let q_dim = head_count * key_length;
// Precompute rotary embeddings
let (cos, sin) = precomput_freqs_cis(
key_length,
rope_freq_base / rope_freq_scaling_factor,
device,
)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
// Load token embeddings and output projection
@ -325,6 +368,17 @@ impl ModelWeights {
feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
};
// Sliding window pattern hardcoded to 6 because it's not explicitly defined
let is_sliding = (layer_idx + 1) % sliding_window_type > 0;
let sliding_window_size = is_sliding.then_some(sliding_window_size);
let layer_rope_frequency = if is_sliding {
rope_freq_base_sliding
} else {
rope_freq_base
};
let rotary_embedding = RotaryEmbedding::new(key_length, layer_rope_frequency, device)?;
// Tracing spans
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
@ -345,8 +399,8 @@ impl ModelWeights {
n_kv_head: head_count_kv,
head_dim: key_length,
q_dim,
cos: cos.clone(),
sin: sin.clone(),
sliding_window_size,
rotary_embedding,
neg_inf: neg_inf.clone(),
kv_cache: None,
span_attn,
@ -363,43 +417,29 @@ impl ModelWeights {
layers,
norm,
output: QMatMul::from_qtensor(output)?,
masks: HashMap::new(),
span,
span_output,
})
}
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), device)?;
self.masks.insert(t, mask.clone());
Ok(mask)
}
}
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mask = if seq_len == 1 {
None
} else {
Some(self.mask(seq_len, x.device())?)
};
let (b_sz, seq_len) = x.dims2()?;
let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?;
layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;
for layer in self.layers.iter_mut() {
let attention_mask = if seq_len == 1 {
None
} else {
Some(layer.mask(b_sz, seq_len, index_pos, x.dtype(), x.device())?)
};
// Attention block
let residual = &layer_in;
let x = layer.attention_norm.forward(&layer_in)?;
let x = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
let x = layer.forward_attn(&x, attention_mask.as_ref(), index_pos)?;
let x = layer.post_attention_norm.forward(&x)?;
let x = (x + residual)?;