From 99bd69f3831efbbf4a5553dbd684d9156161eca0 Mon Sep 17 00:00:00 2001 From: Kyle Birnbaum Date: Tue, 22 Apr 2025 20:39:03 -0700 Subject: [PATCH] fixed quantized-gemma example (#2914) * fixed quantized-gemma example * lint --- candle-examples/examples/quantized-gemma/main.rs | 2 +- candle-transformers/src/models/quantized_gemma3.rs | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/candle-examples/examples/quantized-gemma/main.rs b/candle-examples/examples/quantized-gemma/main.rs index 543acde5..48f4b1dc 100644 --- a/candle-examples/examples/quantized-gemma/main.rs +++ b/candle-examples/examples/quantized-gemma/main.rs @@ -224,7 +224,7 @@ fn main() -> anyhow::Result<()> { } } // Format for Gemma 3 chat/instruction format - format!("user\n{prompt}\n\nmodel\n") + format!(" user\n{prompt}\n model\n") } }; print!("{}", &prompt_str); diff --git a/candle-transformers/src/models/quantized_gemma3.rs b/candle-transformers/src/models/quantized_gemma3.rs index b5cbdf89..929f4936 100644 --- a/candle-transformers/src/models/quantized_gemma3.rs +++ b/candle-transformers/src/models/quantized_gemma3.rs @@ -241,12 +241,20 @@ impl ModelWeights { .and_then(|m| m.to_f32()) .unwrap_or(1000000f32); + let rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor") + .and_then(|m| m.to_f32()) + .unwrap_or(8f32); + // Compute the dimensions for queries, keys, and values // These are the total dimensions when projected across all heads let q_dim = head_count * key_length; // Precompute rotary embeddings - let (cos, sin) = precomput_freqs_cis(key_length, rope_freq_base, device)?; + let (cos, sin) = precomput_freqs_cis( + key_length, + rope_freq_base / rope_freq_scaling_factor, + device, + )?; let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; // Load token embeddings and output projection