fixed quantized-gemma example (#2914)

* fixed quantized-gemma example

* lint
This commit is contained in:
Kyle Birnbaum
2025-04-22 20:39:03 -07:00
committed by GitHub
parent a4c56a958e
commit 99bd69f383
2 changed files with 10 additions and 2 deletions

View File

@@ -241,12 +241,20 @@ impl ModelWeights {
.and_then(|m| m.to_f32())
.unwrap_or(1000000f32);
let rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
.and_then(|m| m.to_f32())
.unwrap_or(8f32);
// Compute the dimensions for queries, keys, and values
// These are the total dimensions when projected across all heads
let q_dim = head_count * key_length;
// Precompute rotary embeddings
-        let (cos, sin) = precomput_freqs_cis(key_length, rope_freq_base, device)?;
+        let (cos, sin) = precomput_freqs_cis(
+            key_length,
+            rope_freq_base / rope_freq_scaling_factor,
+            device,
+        )?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
// Load token embeddings and output projection