From 3aeb9575c7695cf6f4207bb8989fac4db13bf290 Mon Sep 17 00:00:00 2001
From: Kyle Birnbaum
Date: Thu, 24 Apr 2025 20:47:48 -0700
Subject: [PATCH] Fixed Quantized Gemma3 Model and example (#2918)

* removed scale factor from computation and made quantized gemma3 work similarly to non-quantized gemma3

* created default consts, replaced is_sliding with Option holding a window_size
---
 .../src/models/quantized_gemma3.rs | 198 +++++++++++-------
 1 file changed, 119 insertions(+), 79 deletions(-)

diff --git a/candle-transformers/src/models/quantized_gemma3.rs b/candle-transformers/src/models/quantized_gemma3.rs
index 929f4936..bc5b9e7f 100644
--- a/candle-transformers/src/models/quantized_gemma3.rs
+++ b/candle-transformers/src/models/quantized_gemma3.rs
@@ -14,15 +14,18 @@
 //! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
 //!
 
-use std::collections::HashMap;
-
 use crate::quantized_nn::RmsNorm;
 use candle::quantized::gguf_file;
 use candle::quantized::QTensor;
+use candle::D;
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module};
 
 pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
+pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6;
+pub const DEFAULT_ROPE_FREQUENCY: f32 = 1_000_000.;
+pub const DEFAULT_ROPE_FREQUENCY_SLIDING: f32 = 10_000.;
+pub const DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR: f32 = 1.;
 
 #[derive(Debug, Clone)]
 struct QMatMul {
@@ -61,7 +64,44 @@ impl Module for Mlp {
 }
 
 #[derive(Debug, Clone)]
-pub struct LayerWeights {
+struct RotaryEmbedding {
+    sin: Tensor,
+    cos: Tensor,
+}
+
+impl RotaryEmbedding {
+    fn new(head_dim: usize, rope_frequency: f32, device: &Device) -> Result<Self> {
+        let theta: Vec<_> = (0..head_dim)
+            .step_by(2)
+            .map(|i| 1f32 / rope_frequency.powf(i as f32 / head_dim as f32))
+            .collect();
+        let theta = Tensor::new(theta.as_slice(), device)?;
+        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+            .to_dtype(DType::F32)?
+            .reshape((MAX_SEQ_LEN, 1))?
+            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+        let cos = idx_theta.cos()?;
+        let sin = idx_theta.sin()?;
+        Ok(Self { sin, cos })
+    }
+
+    fn apply_rotary_emb_qkv(
+        &self,
+        q: &Tensor,
+        k: &Tensor,
+        index_pos: usize,
+    ) -> Result<(Tensor, Tensor)> {
+        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
+        let cos = self.cos.narrow(0, index_pos, seq_len)?;
+        let sin = self.sin.narrow(0, index_pos, seq_len)?;
+        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
+        Ok((q_embed, k_embed))
+    }
+}
+
+#[derive(Debug, Clone)]
+struct LayerWeights {
     // Attention components
     attention_wq: QMatMul,
     attention_wk: QMatMul,
@@ -87,38 +127,54 @@ pub struct LayerWeights {
     head_dim: usize, // Dimension of each head
     q_dim: usize, // Total dimension for queries
-    // Rotary embedding
-    cos: Tensor,
-    sin: Tensor,
+    sliding_window_size: Option<usize>,
+
+    rotary_embedding: RotaryEmbedding,
 
     neg_inf: Tensor,
 
     // Cache
-    pub kv_cache: Option<(Tensor, Tensor)>,
+    kv_cache: Option<(Tensor, Tensor)>,
 
     // Tracing
    span_attn: tracing::Span,
    span_mlp: tracing::Span,
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
-    let shape = mask.shape();
-    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
-    Ok(m)
-}
-
 impl LayerWeights {
-    fn apply_rotary_emb_qkv(
+    fn mask(
         &self,
-        q: &Tensor,
-        k: &Tensor,
+        b_sz: usize,
+        seq_len: usize,
         index_pos: usize,
-    ) -> Result<(Tensor, Tensor)> {
-        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
-        let cos = self.cos.narrow(0, index_pos, seq_len)?;
-        let sin = self.sin.narrow(0, index_pos, seq_len)?;
-        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
-        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
-        Ok((q_embed, k_embed))
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Tensor> {
+        let mask: Vec<_> = if let Some(sliding_window_size) = self.sliding_window_size {
+            (0..seq_len)
+                .flat_map(|i| {
+                    (0..seq_len).map(move |j| {
+                        if i < j || j + sliding_window_size < i {
+                            0u32
+                        } else {
+                            1u32
+                        }
+                    })
+                })
+                .collect()
+        } else {
+            (0..seq_len)
+                .flat_map(|i| (0..seq_len).map(move |j| if i < j { 0u32 } else { 1u32 }))
+                .collect()
+        };
+        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
+        let mask = if index_pos > 0 {
+            let mask0 = Tensor::zeros((seq_len, index_pos), DType::F32, device)?;
+            Tensor::cat(&[&mask0, &mask], D::Minus1)?
+        } else {
+            mask
+        };
+        mask.expand((b_sz, 1, seq_len, seq_len + index_pos))?
+            .to_dtype(dtype)
     }
 
     fn forward_attn(
@@ -147,7 +203,9 @@ impl LayerWeights {
         let q = self.attention_q_norm.forward(&q.contiguous()?)?;
         let k = self.attention_k_norm.forward(&k.contiguous()?)?;
 
-        let (q, k) = self.apply_rotary_emb_qkv(&q, &k, index_pos)?;
+        let (q, k) = self
+            .rotary_embedding
+            .apply_rotary_emb_qkv(&q, &k, index_pos)?;
 
         let (k, v) = match &self.kv_cache {
             None => (k, v),
@@ -173,7 +231,8 @@ impl LayerWeights {
 
         if let Some(mask) = mask {
             let mask = mask.broadcast_as(attn_weights.shape())?;
-            attn_weights = masked_fill(&attn_weights, &mask, &self.neg_inf)?;
+            let neg_inf = self.neg_inf.broadcast_as(attn_weights.dims())?;
+            attn_weights = mask.eq(0u32)?.where_cond(&neg_inf, &attn_weights)?;
         }
 
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
@@ -191,33 +250,13 @@ impl LayerWeights {
 pub struct ModelWeights {
     tok_embeddings: Embedding,
     embedding_length: usize,
-    pub layers: Vec<LayerWeights>,
+    layers: Vec<LayerWeights>,
     norm: RmsNorm,
     output: QMatMul,
-    masks: HashMap<usize, Tensor>,
     span: tracing::Span,
     span_output: tracing::Span,
 }
 
-fn precomput_freqs_cis(
-    head_dim: usize,
-    freq_base: f32,
-    device: &Device,
-) -> Result<(Tensor, Tensor)> {
-    let theta: Vec<_> = (0..head_dim)
-        .step_by(2)
-        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
-        .collect();
-    let theta = Tensor::new(theta.as_slice(), device)?;
-    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
-        .to_dtype(DType::F32)?
-        .reshape((MAX_SEQ_LEN, 1))?
-        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-    let cos = idx_theta.cos()?;
-    let sin = idx_theta.sin()?;
-    Ok((cos, sin))
-}
-
 impl ModelWeights {
     pub fn from_gguf<R: std::io::Seek + std::io::Read>(
         ct: gguf_file::Content,
@@ -236,25 +275,29 @@ impl ModelWeights {
         let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
         let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
         let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
+        let sliding_window_size = md_get("gemma3.attention.sliding_window")?.to_u32()? as usize;
+
+        let sliding_window_type = md_get("gemma3.attention.sliding_window_type")
+            .and_then(|m| Ok(m.to_u32()? as usize))
+            .unwrap_or(DEFAULT_SLIDING_WINDOW_TYPE);
 
         let rope_freq_base = md_get("gemma3.rope.freq_base")
             .and_then(|m| m.to_f32())
-            .unwrap_or(1000000f32);
+            .unwrap_or(DEFAULT_ROPE_FREQUENCY);
 
-        let rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
+        let rope_freq_base_sliding = md_get("gemma3.rope.local_freq_base")
             .and_then(|m| m.to_f32())
-            .unwrap_or(8f32);
+            .unwrap_or(DEFAULT_ROPE_FREQUENCY_SLIDING);
+
+        // Unused in Llama.cpp so we aren't using it here.
+        let _rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
+            .and_then(|m| m.to_f32())
+            .unwrap_or(DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR);
 
         // Compute the dimensions for queries, keys, and values
         // These are the total dimensions when projected across all heads
         let q_dim = head_count * key_length;
 
-        // Precompute rotary embeddings
-        let (cos, sin) = precomput_freqs_cis(
-            key_length,
-            rope_freq_base / rope_freq_scaling_factor,
-            device,
-        )?;
         let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
         // Load token embeddings and output projection
@@ -325,6 +368,17 @@ impl ModelWeights {
                 feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
             };
 
+            // Sliding window pattern hardcoded to 6 because it's not explicitly defined
+            let is_sliding = (layer_idx + 1) % sliding_window_type > 0;
+            let sliding_window_size = is_sliding.then_some(sliding_window_size);
+            let layer_rope_frequency = if is_sliding {
+                rope_freq_base_sliding
+            } else {
+                rope_freq_base
+            };
+
+            let rotary_embedding = RotaryEmbedding::new(key_length, layer_rope_frequency, device)?;
+
             // Tracing spans
             let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
             let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
@@ -345,8 +399,8 @@ impl ModelWeights {
                 n_kv_head: head_count_kv,
                 head_dim: key_length,
                 q_dim,
-                cos: cos.clone(),
-                sin: sin.clone(),
+                sliding_window_size,
+                rotary_embedding,
                 neg_inf: neg_inf.clone(),
                 kv_cache: None,
                 span_attn,
@@ -363,43 +417,29 @@ impl ModelWeights {
             layers,
             norm,
             output: QMatMul::from_qtensor(output)?,
-            masks: HashMap::new(),
             span,
             span_output,
         })
     }
 
-    fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
-        if let Some(mask) = self.masks.get(&t) {
-            Ok(mask.clone())
-        } else {
-            let mask: Vec<_> = (0..t)
-                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
-                .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), device)?;
-            self.masks.insert(t, mask.clone());
-            Ok(mask)
-        }
-    }
-
     pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (_b_sz, seq_len) = x.dims2()?;
-
-        let mask = if seq_len == 1 {
-            None
-        } else {
-            Some(self.mask(seq_len, x.device())?)
-        };
+        let (b_sz, seq_len) = x.dims2()?;
 
         let _enter = self.span.enter();
 
         let mut layer_in = self.tok_embeddings.forward(x)?;
         layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;
 
         for layer in self.layers.iter_mut() {
+            let attention_mask = if seq_len == 1 {
+                None
+            } else {
+                Some(layer.mask(b_sz, seq_len, index_pos, x.dtype(), x.device())?)
+            };
+
             // Attention block
             let residual = &layer_in;
             let x = layer.attention_norm.forward(&layer_in)?;
-            let x = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
+            let x = layer.forward_attn(&x, attention_mask.as_ref(), index_pos)?;
             let x = layer.post_attention_norm.forward(&x)?;
             let x = (x + residual)?;
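
Note for reviewers (not part of the patch): a minimal sketch of how the patched module can be smoke-tested. The GGUF path, the token ids, and the use of anyhow are assumptions made for brevity; a real driver would run the Gemma 3 tokenizer over a prompt and loop over decode steps.

    // Minimal sketch: load a GGUF-quantized Gemma 3 model and run one forward pass.
    // The file path and token ids below are placeholders, not values from the patch.
    use candle::quantized::gguf_file;
    use candle::{Device, Tensor};
    use candle_transformers::models::quantized_gemma3::ModelWeights;

    fn main() -> anyhow::Result<()> {
        let device = Device::Cpu;
        let mut file = std::fs::File::open("gemma3-4b-it-q4_k_m.gguf")?; // placeholder path
        let content = gguf_file::Content::read(&mut file)?;
        let mut model = ModelWeights::from_gguf(content, &mut file, &device)?;

        // Hypothetical prompt token ids; use the Gemma 3 tokenizer in practice.
        let tokens: Vec<u32> = vec![2, 106, 2364];
        let input = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;

        // index_pos is 0 for the prompt; later decode steps pass the running offset
        // so each layer can build its (sliding-window) mask and index the KV cache.
        let logits = model.forward(&input, 0)?;
        println!("logits shape: {:?}", logits.shape());
        Ok(())
    }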