fixed quantized-gemma example (#2914)

* fixed quantized-gemma example

* lint
Author: Kyle Birnbaum
Date: 2025-04-22 20:39:03 -07:00 (committed by GitHub)
Parent: a4c56a958e
Commit: 99bd69f383
2 changed files with 10 additions and 2 deletions


@@ -224,7 +224,7 @@ fn main() -> anyhow::Result<()> {
                 }
             }
             // Format for Gemma 3 chat/instruction format
-            format!("<start_of_turn>user\n{prompt}\n<end_of_turn>\n<start_of_turn>model\n")
+            format!("<start_of_turn> user\n{prompt}<end_of_turn>\n<start_of_turn> model\n")
         }
     };
     print!("{}", &prompt_str);
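For context, a minimal sketch of what the corrected template expands to. The helper name `render_gemma3_prompt` and the sample prompt are invented for illustration; the `format!` body mirrors the `+` line in the hunk above.

```rust
// Hypothetical helper, not part of the example: renders the Gemma 3
// turn markup produced by the fixed line above.
fn render_gemma3_prompt(prompt: &str) -> String {
    // Each turn is wrapped in <start_of_turn>/<end_of_turn> markers;
    // the trailing "model\n" cues the model to begin its reply.
    format!("<start_of_turn> user\n{prompt}<end_of_turn>\n<start_of_turn> model\n")
}

fn main() {
    let rendered = render_gemma3_prompt("Why is the sky blue?");
    print!("{rendered}");
    // Prints:
    // <start_of_turn> user
    // Why is the sky blue?<end_of_turn>
    // <start_of_turn> model
}
```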


@@ -241,12 +241,20 @@ impl ModelWeights {
         let rope_freq_base = md_get("gemma3.rope.freq_base")
             .and_then(|m| m.to_f32())
             .unwrap_or(1000000f32);
+        let rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
+            .and_then(|m| m.to_f32())
+            .unwrap_or(8f32);
         // Compute the dimensions for queries, keys, and values
         // These are the total dimensions when projected across all heads
         let q_dim = head_count * key_length;
+        // Precompute rotary embeddings
-        let (cos, sin) = precomput_freqs_cis(key_length, rope_freq_base, device)?;
+        let (cos, sin) = precomput_freqs_cis(
+            key_length,
+            rope_freq_base / rope_freq_scaling_factor,
+            device,
+        )?;
         let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
         // Load token embeddings and output projection
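This hunk reads `gemma3.rope.scaling.factor` from the GGUF metadata (defaulting to 8) and divides the RoPE base frequency by it before building the rotary tables. Below is a minimal sketch of such a precomputation, assuming a signature like the example's `precomput_freqs_cis(head_dim, base, device)`; the `MAX_SEQ_LEN` constant, the function name `freqs_cis_sketch`, and the `candle_core` imports are assumptions, not taken from this diff.

```rust
use candle_core::{DType, Device, Result, Tensor};

// Assumed table size; candle's quantized models define a similar constant.
const MAX_SEQ_LEN: usize = 4096;

// Sketch of rotary-embedding precomputation with the same shape of API as
// the example's `precomput_freqs_cis`.
fn freqs_cis_sketch(head_dim: usize, base: f32, device: &Device) -> Result<(Tensor, Tensor)> {
    // One inverse frequency per pair of dimensions: base^(-i/d) for even i.
    let inv_freq: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1.0 / base.powf(i as f32 / head_dim as f32))
        .collect();
    let n = inv_freq.len();
    let inv_freq = Tensor::from_vec(inv_freq, (1, n), device)?;
    // Outer product of positions and inverse frequencies gives the angle table.
    let positions = Tensor::arange(0u32, MAX_SEQ_LEN as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((MAX_SEQ_LEN, 1))?;
    let angles = positions.matmul(&inv_freq)?;
    Ok((angles.cos()?, angles.sin()?))
}
```

Under these assumptions, the call in the hunk amounts to `freqs_cis_sketch(key_length, rope_freq_base / rope_freq_scaling_factor, device)?`: the cos/sin tables are built from the scaled base (1,000,000 / 8 with the defaults) rather than the raw one.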