mirror of
https://github.com/huggingface/candle.git
synced 2025-06-14 09:57:10 +00:00
Fixed Quantized Gemma3 Model and example (#2918)
* removed scale factor from computation and made quantized gemma3 work similarly to non-quantized gemma3 * created default consts, replaced is_sliding with Option holding a window_size
This commit is contained in:
@ -14,15 +14,18 @@
|
||||
//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
|
||||
//!
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::quantized_nn::RmsNorm;
|
||||
use candle::quantized::gguf_file;
|
||||
use candle::quantized::QTensor;
|
||||
use candle::D;
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
use candle_nn::{Embedding, Module};
|
||||
|
||||
pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
|
||||
pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6;
|
||||
pub const DEFAULT_ROPE_FREQUENCY: f32 = 1_000_000.;
|
||||
pub const DEFAULT_ROPE_FREQUENCY_SLIDING: f32 = 10_000.;
|
||||
pub const DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR: f32 = 1.;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct QMatMul {
|
||||
@ -61,7 +64,44 @@ impl Module for Mlp {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LayerWeights {
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(head_dim: usize, rope_frequency: f32, device: &Device) -> Result<Self> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / rope_frequency.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
Ok(Self { sin, cos })
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
index_pos: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.sin.narrow(0, index_pos, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct LayerWeights {
|
||||
// Attention components
|
||||
attention_wq: QMatMul,
|
||||
attention_wk: QMatMul,
|
||||
@ -87,38 +127,54 @@ pub struct LayerWeights {
|
||||
head_dim: usize, // Dimension of each head
|
||||
q_dim: usize, // Total dimension for queries
|
||||
|
||||
// Rotary embedding
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
sliding_window_size: Option<usize>,
|
||||
|
||||
rotary_embedding: RotaryEmbedding,
|
||||
neg_inf: Tensor,
|
||||
|
||||
// Cache
|
||||
pub kv_cache: Option<(Tensor, Tensor)>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
|
||||
// Tracing
|
||||
span_attn: tracing::Span,
|
||||
span_mlp: tracing::Span,
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
impl LayerWeights {
|
||||
fn apply_rotary_emb_qkv(
|
||||
fn mask(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
b_sz: usize,
|
||||
seq_len: usize,
|
||||
index_pos: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.sin.narrow(0, index_pos, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let mask: Vec<_> = if let Some(sliding_window_size) = self.sliding_window_size {
|
||||
(0..seq_len)
|
||||
.flat_map(|i| {
|
||||
(0..seq_len).map(move |j| {
|
||||
if i < j || j + sliding_window_size < i {
|
||||
0u32
|
||||
} else {
|
||||
1u32
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
(0..seq_len)
|
||||
.flat_map(|i| (0..seq_len).map(move |j| if i < j { 0u32 } else { 1u32 }))
|
||||
.collect()
|
||||
};
|
||||
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
|
||||
let mask = if index_pos > 0 {
|
||||
let mask0 = Tensor::zeros((seq_len, index_pos), DType::F32, device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
mask.expand((b_sz, 1, seq_len, seq_len + index_pos))?
|
||||
.to_dtype(dtype)
|
||||
}
|
||||
|
||||
fn forward_attn(
|
||||
@ -147,7 +203,9 @@ impl LayerWeights {
|
||||
let q = self.attention_q_norm.forward(&q.contiguous()?)?;
|
||||
let k = self.attention_k_norm.forward(&k.contiguous()?)?;
|
||||
|
||||
let (q, k) = self.apply_rotary_emb_qkv(&q, &k, index_pos)?;
|
||||
let (q, k) = self
|
||||
.rotary_embedding
|
||||
.apply_rotary_emb_qkv(&q, &k, index_pos)?;
|
||||
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k, v),
|
||||
@ -173,7 +231,8 @@ impl LayerWeights {
|
||||
|
||||
if let Some(mask) = mask {
|
||||
let mask = mask.broadcast_as(attn_weights.shape())?;
|
||||
attn_weights = masked_fill(&attn_weights, &mask, &self.neg_inf)?;
|
||||
let neg_inf = self.neg_inf.broadcast_as(attn_weights.dims())?;
|
||||
attn_weights = mask.eq(0u32)?.where_cond(&neg_inf, &attn_weights)?;
|
||||
}
|
||||
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
@ -191,33 +250,13 @@ impl LayerWeights {
|
||||
pub struct ModelWeights {
|
||||
tok_embeddings: Embedding,
|
||||
embedding_length: usize,
|
||||
pub layers: Vec<LayerWeights>,
|
||||
layers: Vec<LayerWeights>,
|
||||
norm: RmsNorm,
|
||||
output: QMatMul,
|
||||
masks: HashMap<usize, Tensor>,
|
||||
span: tracing::Span,
|
||||
span_output: tracing::Span,
|
||||
}
|
||||
|
||||
fn precomput_freqs_cis(
|
||||
head_dim: usize,
|
||||
freq_base: f32,
|
||||
device: &Device,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
ct: gguf_file::Content,
|
||||
@ -236,25 +275,29 @@ impl ModelWeights {
|
||||
let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
|
||||
let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
|
||||
let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||
let sliding_window_size = md_get("gemma3.attention.sliding_window")?.to_u32()? as usize;
|
||||
|
||||
let sliding_window_type = md_get("gemma3.attention.sliding_window_type")
|
||||
.and_then(|m| Ok(m.to_u32()? as usize))
|
||||
.unwrap_or(DEFAULT_SLIDING_WINDOW_TYPE);
|
||||
|
||||
let rope_freq_base = md_get("gemma3.rope.freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(1000000f32);
|
||||
.unwrap_or(DEFAULT_ROPE_FREQUENCY);
|
||||
|
||||
let rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
|
||||
let rope_freq_base_sliding = md_get("gemma3.rope.local_freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(8f32);
|
||||
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SLIDING);
|
||||
|
||||
// Unused in Llama.cpp so we aren't using it here.
|
||||
let _rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR);
|
||||
|
||||
// Compute the dimensions for queries, keys, and values
|
||||
// These are the total dimensions when projected across all heads
|
||||
let q_dim = head_count * key_length;
|
||||
|
||||
// Precompute rotary embeddings
|
||||
let (cos, sin) = precomput_freqs_cis(
|
||||
key_length,
|
||||
rope_freq_base / rope_freq_scaling_factor,
|
||||
device,
|
||||
)?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||
|
||||
// Load token embeddings and output projection
|
||||
@ -325,6 +368,17 @@ impl ModelWeights {
|
||||
feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
|
||||
};
|
||||
|
||||
// Sliding window pattern hardcoded to 6 because it's not explicitly defined
|
||||
let is_sliding = (layer_idx + 1) % sliding_window_type > 0;
|
||||
let sliding_window_size = is_sliding.then_some(sliding_window_size);
|
||||
let layer_rope_frequency = if is_sliding {
|
||||
rope_freq_base_sliding
|
||||
} else {
|
||||
rope_freq_base
|
||||
};
|
||||
|
||||
let rotary_embedding = RotaryEmbedding::new(key_length, layer_rope_frequency, device)?;
|
||||
|
||||
// Tracing spans
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||
@ -345,8 +399,8 @@ impl ModelWeights {
|
||||
n_kv_head: head_count_kv,
|
||||
head_dim: key_length,
|
||||
q_dim,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
sliding_window_size,
|
||||
rotary_embedding,
|
||||
neg_inf: neg_inf.clone(),
|
||||
kv_cache: None,
|
||||
span_attn,
|
||||
@ -363,43 +417,29 @@ impl ModelWeights {
|
||||
layers,
|
||||
norm,
|
||||
output: QMatMul::from_qtensor(output)?,
|
||||
masks: HashMap::new(),
|
||||
span,
|
||||
span_output,
|
||||
})
|
||||
}
|
||||
|
||||
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
|
||||
if let Some(mask) = self.masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
} else {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), device)?;
|
||||
self.masks.insert(t, mask.clone());
|
||||
Ok(mask)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len) = x.dims2()?;
|
||||
|
||||
let mask = if seq_len == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(self.mask(seq_len, x.device())?)
|
||||
};
|
||||
let (b_sz, seq_len) = x.dims2()?;
|
||||
let _enter = self.span.enter();
|
||||
|
||||
let mut layer_in = self.tok_embeddings.forward(x)?;
|
||||
layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;
|
||||
|
||||
for layer in self.layers.iter_mut() {
|
||||
let attention_mask = if seq_len == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(layer.mask(b_sz, seq_len, index_pos, x.dtype(), x.device())?)
|
||||
};
|
||||
|
||||
// Attention block
|
||||
let residual = &layer_in;
|
||||
let x = layer.attention_norm.forward(&layer_in)?;
|
||||
let x = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
|
||||
let x = layer.forward_attn(&x, attention_mask.as_ref(), index_pos)?;
|
||||
let x = layer.post_attention_norm.forward(&x)?;
|
||||
let x = (x + residual)?;
|
||||
|
||||
|
Reference in New Issue
Block a user