mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Fixed Quantized Gemma3 Model and example (#2918)
* removed scale factor from computation and made quantized gemma3 work similarly to non-quantized gemma3 * created default consts, replaced is_sliding with Option holding a window_size
This commit is contained in:
@ -14,15 +14,18 @@
|
|||||||
//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
|
//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
|
||||||
//!
|
//!
|
||||||
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use crate::quantized_nn::RmsNorm;
|
use crate::quantized_nn::RmsNorm;
|
||||||
use candle::quantized::gguf_file;
|
use candle::quantized::gguf_file;
|
||||||
use candle::quantized::QTensor;
|
use candle::quantized::QTensor;
|
||||||
|
use candle::D;
|
||||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||||
use candle_nn::{Embedding, Module};
|
use candle_nn::{Embedding, Module};
|
||||||
|
|
||||||
pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
|
pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
|
||||||
|
pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6;
|
||||||
|
pub const DEFAULT_ROPE_FREQUENCY: f32 = 1_000_000.;
|
||||||
|
pub const DEFAULT_ROPE_FREQUENCY_SLIDING: f32 = 10_000.;
|
||||||
|
pub const DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR: f32 = 1.;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct QMatMul {
|
struct QMatMul {
|
||||||
@ -61,7 +64,44 @@ impl Module for Mlp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct LayerWeights {
|
struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
fn new(head_dim: usize, rope_frequency: f32, device: &Device) -> Result<Self> {
|
||||||
|
let theta: Vec<_> = (0..head_dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / rope_frequency.powf(i as f32 / head_dim as f32))
|
||||||
|
.collect();
|
||||||
|
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||||
|
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((MAX_SEQ_LEN, 1))?
|
||||||
|
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||||
|
let cos = idx_theta.cos()?;
|
||||||
|
let sin = idx_theta.sin()?;
|
||||||
|
Ok(Self { sin, cos })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_rotary_emb_qkv(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
index_pos: usize,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, index_pos, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, index_pos, seq_len)?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct LayerWeights {
|
||||||
// Attention components
|
// Attention components
|
||||||
attention_wq: QMatMul,
|
attention_wq: QMatMul,
|
||||||
attention_wk: QMatMul,
|
attention_wk: QMatMul,
|
||||||
@ -87,38 +127,54 @@ pub struct LayerWeights {
|
|||||||
head_dim: usize, // Dimension of each head
|
head_dim: usize, // Dimension of each head
|
||||||
q_dim: usize, // Total dimension for queries
|
q_dim: usize, // Total dimension for queries
|
||||||
|
|
||||||
// Rotary embedding
|
sliding_window_size: Option<usize>,
|
||||||
cos: Tensor,
|
|
||||||
sin: Tensor,
|
rotary_embedding: RotaryEmbedding,
|
||||||
neg_inf: Tensor,
|
neg_inf: Tensor,
|
||||||
|
|
||||||
// Cache
|
// Cache
|
||||||
pub kv_cache: Option<(Tensor, Tensor)>,
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
|
||||||
// Tracing
|
// Tracing
|
||||||
span_attn: tracing::Span,
|
span_attn: tracing::Span,
|
||||||
span_mlp: tracing::Span,
|
span_mlp: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
|
|
||||||
let shape = mask.shape();
|
|
||||||
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
|
|
||||||
Ok(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LayerWeights {
|
impl LayerWeights {
|
||||||
fn apply_rotary_emb_qkv(
|
fn mask(
|
||||||
&self,
|
&self,
|
||||||
q: &Tensor,
|
b_sz: usize,
|
||||||
k: &Tensor,
|
seq_len: usize,
|
||||||
index_pos: usize,
|
index_pos: usize,
|
||||||
) -> Result<(Tensor, Tensor)> {
|
dtype: DType,
|
||||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
device: &Device,
|
||||||
let cos = self.cos.narrow(0, index_pos, seq_len)?;
|
) -> Result<Tensor> {
|
||||||
let sin = self.sin.narrow(0, index_pos, seq_len)?;
|
let mask: Vec<_> = if let Some(sliding_window_size) = self.sliding_window_size {
|
||||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
(0..seq_len)
|
||||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
.flat_map(|i| {
|
||||||
Ok((q_embed, k_embed))
|
(0..seq_len).map(move |j| {
|
||||||
|
if i < j || j + sliding_window_size < i {
|
||||||
|
0u32
|
||||||
|
} else {
|
||||||
|
1u32
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
(0..seq_len)
|
||||||
|
.flat_map(|i| (0..seq_len).map(move |j| if i < j { 0u32 } else { 1u32 }))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
|
||||||
|
let mask = if index_pos > 0 {
|
||||||
|
let mask0 = Tensor::zeros((seq_len, index_pos), DType::F32, device)?;
|
||||||
|
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||||
|
} else {
|
||||||
|
mask
|
||||||
|
};
|
||||||
|
mask.expand((b_sz, 1, seq_len, seq_len + index_pos))?
|
||||||
|
.to_dtype(dtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward_attn(
|
fn forward_attn(
|
||||||
@ -147,7 +203,9 @@ impl LayerWeights {
|
|||||||
let q = self.attention_q_norm.forward(&q.contiguous()?)?;
|
let q = self.attention_q_norm.forward(&q.contiguous()?)?;
|
||||||
let k = self.attention_k_norm.forward(&k.contiguous()?)?;
|
let k = self.attention_k_norm.forward(&k.contiguous()?)?;
|
||||||
|
|
||||||
let (q, k) = self.apply_rotary_emb_qkv(&q, &k, index_pos)?;
|
let (q, k) = self
|
||||||
|
.rotary_embedding
|
||||||
|
.apply_rotary_emb_qkv(&q, &k, index_pos)?;
|
||||||
|
|
||||||
let (k, v) = match &self.kv_cache {
|
let (k, v) = match &self.kv_cache {
|
||||||
None => (k, v),
|
None => (k, v),
|
||||||
@ -173,7 +231,8 @@ impl LayerWeights {
|
|||||||
|
|
||||||
if let Some(mask) = mask {
|
if let Some(mask) = mask {
|
||||||
let mask = mask.broadcast_as(attn_weights.shape())?;
|
let mask = mask.broadcast_as(attn_weights.shape())?;
|
||||||
attn_weights = masked_fill(&attn_weights, &mask, &self.neg_inf)?;
|
let neg_inf = self.neg_inf.broadcast_as(attn_weights.dims())?;
|
||||||
|
attn_weights = mask.eq(0u32)?.where_cond(&neg_inf, &attn_weights)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||||
@ -191,33 +250,13 @@ impl LayerWeights {
|
|||||||
pub struct ModelWeights {
|
pub struct ModelWeights {
|
||||||
tok_embeddings: Embedding,
|
tok_embeddings: Embedding,
|
||||||
embedding_length: usize,
|
embedding_length: usize,
|
||||||
pub layers: Vec<LayerWeights>,
|
layers: Vec<LayerWeights>,
|
||||||
norm: RmsNorm,
|
norm: RmsNorm,
|
||||||
output: QMatMul,
|
output: QMatMul,
|
||||||
masks: HashMap<usize, Tensor>,
|
|
||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
span_output: tracing::Span,
|
span_output: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn precomput_freqs_cis(
|
|
||||||
head_dim: usize,
|
|
||||||
freq_base: f32,
|
|
||||||
device: &Device,
|
|
||||||
) -> Result<(Tensor, Tensor)> {
|
|
||||||
let theta: Vec<_> = (0..head_dim)
|
|
||||||
.step_by(2)
|
|
||||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
|
||||||
.collect();
|
|
||||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
|
||||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
|
||||||
.to_dtype(DType::F32)?
|
|
||||||
.reshape((MAX_SEQ_LEN, 1))?
|
|
||||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
|
||||||
let cos = idx_theta.cos()?;
|
|
||||||
let sin = idx_theta.sin()?;
|
|
||||||
Ok((cos, sin))
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ModelWeights {
|
impl ModelWeights {
|
||||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||||
ct: gguf_file::Content,
|
ct: gguf_file::Content,
|
||||||
@ -236,25 +275,29 @@ impl ModelWeights {
|
|||||||
let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
|
let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
|
||||||
let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
|
let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
|
||||||
let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||||
|
let sliding_window_size = md_get("gemma3.attention.sliding_window")?.to_u32()? as usize;
|
||||||
|
|
||||||
|
let sliding_window_type = md_get("gemma3.attention.sliding_window_type")
|
||||||
|
.and_then(|m| Ok(m.to_u32()? as usize))
|
||||||
|
.unwrap_or(DEFAULT_SLIDING_WINDOW_TYPE);
|
||||||
|
|
||||||
let rope_freq_base = md_get("gemma3.rope.freq_base")
|
let rope_freq_base = md_get("gemma3.rope.freq_base")
|
||||||
.and_then(|m| m.to_f32())
|
.and_then(|m| m.to_f32())
|
||||||
.unwrap_or(1000000f32);
|
.unwrap_or(DEFAULT_ROPE_FREQUENCY);
|
||||||
|
|
||||||
let rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
|
let rope_freq_base_sliding = md_get("gemma3.rope.local_freq_base")
|
||||||
.and_then(|m| m.to_f32())
|
.and_then(|m| m.to_f32())
|
||||||
.unwrap_or(8f32);
|
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SLIDING);
|
||||||
|
|
||||||
|
// Unused in Llama.cpp so we aren't using it here.
|
||||||
|
let _rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
|
||||||
|
.and_then(|m| m.to_f32())
|
||||||
|
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR);
|
||||||
|
|
||||||
// Compute the dimensions for queries, keys, and values
|
// Compute the dimensions for queries, keys, and values
|
||||||
// These are the total dimensions when projected across all heads
|
// These are the total dimensions when projected across all heads
|
||||||
let q_dim = head_count * key_length;
|
let q_dim = head_count * key_length;
|
||||||
|
|
||||||
// Precompute rotary embeddings
|
|
||||||
let (cos, sin) = precomput_freqs_cis(
|
|
||||||
key_length,
|
|
||||||
rope_freq_base / rope_freq_scaling_factor,
|
|
||||||
device,
|
|
||||||
)?;
|
|
||||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||||
|
|
||||||
// Load token embeddings and output projection
|
// Load token embeddings and output projection
|
||||||
@ -325,6 +368,17 @@ impl ModelWeights {
|
|||||||
feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
|
feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Sliding window pattern hardcoded to 6 because it's not explicitly defined
|
||||||
|
let is_sliding = (layer_idx + 1) % sliding_window_type > 0;
|
||||||
|
let sliding_window_size = is_sliding.then_some(sliding_window_size);
|
||||||
|
let layer_rope_frequency = if is_sliding {
|
||||||
|
rope_freq_base_sliding
|
||||||
|
} else {
|
||||||
|
rope_freq_base
|
||||||
|
};
|
||||||
|
|
||||||
|
let rotary_embedding = RotaryEmbedding::new(key_length, layer_rope_frequency, device)?;
|
||||||
|
|
||||||
// Tracing spans
|
// Tracing spans
|
||||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||||
@ -345,8 +399,8 @@ impl ModelWeights {
|
|||||||
n_kv_head: head_count_kv,
|
n_kv_head: head_count_kv,
|
||||||
head_dim: key_length,
|
head_dim: key_length,
|
||||||
q_dim,
|
q_dim,
|
||||||
cos: cos.clone(),
|
sliding_window_size,
|
||||||
sin: sin.clone(),
|
rotary_embedding,
|
||||||
neg_inf: neg_inf.clone(),
|
neg_inf: neg_inf.clone(),
|
||||||
kv_cache: None,
|
kv_cache: None,
|
||||||
span_attn,
|
span_attn,
|
||||||
@ -363,43 +417,29 @@ impl ModelWeights {
|
|||||||
layers,
|
layers,
|
||||||
norm,
|
norm,
|
||||||
output: QMatMul::from_qtensor(output)?,
|
output: QMatMul::from_qtensor(output)?,
|
||||||
masks: HashMap::new(),
|
|
||||||
span,
|
span,
|
||||||
span_output,
|
span_output,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
|
|
||||||
if let Some(mask) = self.masks.get(&t) {
|
|
||||||
Ok(mask.clone())
|
|
||||||
} else {
|
|
||||||
let mask: Vec<_> = (0..t)
|
|
||||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
|
||||||
.collect();
|
|
||||||
let mask = Tensor::from_slice(&mask, (t, t), device)?;
|
|
||||||
self.masks.insert(t, mask.clone());
|
|
||||||
Ok(mask)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
let (_b_sz, seq_len) = x.dims2()?;
|
let (b_sz, seq_len) = x.dims2()?;
|
||||||
|
|
||||||
let mask = if seq_len == 1 {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(self.mask(seq_len, x.device())?)
|
|
||||||
};
|
|
||||||
let _enter = self.span.enter();
|
let _enter = self.span.enter();
|
||||||
|
|
||||||
let mut layer_in = self.tok_embeddings.forward(x)?;
|
let mut layer_in = self.tok_embeddings.forward(x)?;
|
||||||
layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;
|
layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;
|
||||||
|
|
||||||
for layer in self.layers.iter_mut() {
|
for layer in self.layers.iter_mut() {
|
||||||
|
let attention_mask = if seq_len == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(layer.mask(b_sz, seq_len, index_pos, x.dtype(), x.device())?)
|
||||||
|
};
|
||||||
|
|
||||||
// Attention block
|
// Attention block
|
||||||
let residual = &layer_in;
|
let residual = &layer_in;
|
||||||
let x = layer.attention_norm.forward(&layer_in)?;
|
let x = layer.attention_norm.forward(&layer_in)?;
|
||||||
let x = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
|
let x = layer.forward_attn(&x, attention_mask.as_ref(), index_pos)?;
|
||||||
let x = layer.post_attention_norm.forward(&x)?;
|
let x = layer.post_attention_norm.forward(&x)?;
|
||||||
let x = (x + residual)?;
|
let x = (x + residual)?;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user