mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
fixed quantized-gemma example (#2914)
* fixed quantized-gemma example * lint
This commit is contained in:
@ -224,7 +224,7 @@ fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
// Format for Gemma 3 chat/instruction format
|
||||
format!("<start_of_turn>user\n{prompt}\n<end_of_turn>\n<start_of_turn>model\n")
|
||||
format!("<start_of_turn> user\n{prompt}<end_of_turn>\n<start_of_turn> model\n")
|
||||
}
|
||||
};
|
||||
print!("{}", &prompt_str);
|
||||
|
@ -241,12 +241,20 @@ impl ModelWeights {
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(1000000f32);
|
||||
|
||||
let rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(8f32);
|
||||
|
||||
// Compute the dimensions for queries, keys, and values
|
||||
// These are the total dimensions when projected across all heads
|
||||
let q_dim = head_count * key_length;
|
||||
|
||||
// Precompute rotary embeddings
|
||||
let (cos, sin) = precomput_freqs_cis(key_length, rope_freq_base, device)?;
|
||||
let (cos, sin) = precomput_freqs_cis(
|
||||
key_length,
|
||||
rope_freq_base / rope_freq_scaling_factor,
|
||||
device,
|
||||
)?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||
|
||||
// Load token embeddings and output projection
|
||||
|
Reference in New Issue
Block a user