From 6ff0a6999cd9fb05411fd07f14803a658f393dca Mon Sep 17 00:00:00 2001
From: Kyle Birnbaum
Date: Thu, 24 Apr 2025 20:35:08 -0700
Subject: [PATCH] Fixed Gemma3 model and example (#2917)

* gemma3: changed RotaryEmbedding base freq based on layer and sliding window
* Changed attention mask per layer, either normal or sliding
* made attention mask creation slightly more efficient by only creating them once per model iteration
* changed is_sliding to an Option
* clippy
* changed to stop on both <eos> and <end_of_turn> instead of either or
---
 candle-examples/examples/gemma/main.rs   |  40 +++++-
 candle-transformers/src/models/gemma3.rs | 157 +++++++++++++++--------
 2 files changed, 143 insertions(+), 54 deletions(-)
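
Note on the per-layer selection rule (an illustrative sketch, not part of the
patch; the helper below only mirrors the `(layer_idx + 1) %
cfg.sliding_window_pattern > 0` test used in `Model::new`): every
`sliding_window_pattern`-th layer attends globally and keeps `rope_theta` as
its RoPE base, while all other layers attend within `sliding_window` tokens
and use `rope_local_base_freq`.

    // Hypothetical helper, not in the diff: maps a layer index to its window.
    // Some(window) => local attention, rotating KV cache, rope_local_base_freq.
    // None         => global attention, normal KV cache, rope_theta.
    fn layer_sliding_window(
        layer_idx: usize,
        sliding_window_pattern: usize, // e.g. a pattern of 6 gives 5 local layers per global one
        sliding_window: usize,
    ) -> Option<usize> {
        let is_sliding = (layer_idx + 1) % sliding_window_pattern > 0;
        is_sliding.then_some(sliding_window)
    }
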
diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs
index f6247c02..81167ac2 100644
--- a/candle-examples/examples/gemma/main.rs
+++ b/candle-examples/examples/gemma/main.rs
@@ -124,6 +124,17 @@ impl TextGeneration {
             Some(token) => token,
             None => anyhow::bail!("cannot find the <eos> token"),
         };
+
+        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
+            Some(token) => token,
+            None => {
+                println!(
+                    "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
+                );
+                eos_token
+            }
+        };
+
         let start_gen = std::time::Instant::now();
         for index in 0..sample_len {
             let context_size = if index > 0 { 1 } else { tokens.len() };
@@ -146,7 +157,7 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             generated_tokens += 1;
-            if next_token == eos_token {
+            if next_token == eos_token || next_token == eot_token {
                 break;
             }
             if let Some(t) = self.tokenizer.next_token(next_token)? {
@@ -350,6 +361,31 @@ fn main() -> Result<()> {
         args.repeat_last_n,
         &device,
     );
-    pipeline.run(&args.prompt, args.sample_len)?;
+
+    let prompt = match args.which {
+        Which::Base2B
+        | Which::Base7B
+        | Which::Instruct2B
+        | Which::Instruct7B
+        | Which::InstructV1_1_2B
+        | Which::InstructV1_1_7B
+        | Which::CodeBase2B
+        | Which::CodeBase7B
+        | Which::CodeInstruct2B
+        | Which::CodeInstruct7B
+        | Which::BaseV2_2B
+        | Which::InstructV2_2B
+        | Which::BaseV2_9B
+        | Which::InstructV2_9B
+        | Which::BaseV3_1B => args.prompt,
+        Which::InstructV3_1B => {
+            format!(
+                "<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
+                args.prompt
+            )
+        }
+    };
+
+    pipeline.run(&prompt, args.sample_len)?;
     Ok(())
 }
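
For reference, a minimal sketch of what the example now does for the Gemma 3
instruct variant (the turn markers and the two stop tokens are the ones used in
the diff above; the helper names here are made up for illustration):

    // Hypothetical helpers, not in the diff.
    // Wraps a raw prompt in Gemma chat-turn markers, as done for Which::InstructV3_1B.
    fn wrap_instruct_prompt(prompt: &str) -> String {
        format!("<start_of_turn> user\n{prompt}<end_of_turn>\n<start_of_turn> model\n")
    }

    // Generation stops as soon as either stop token is sampled.
    fn is_stop_token(next_token: u32, eos_token: u32, eot_token: u32) -> bool {
        next_token == eos_token || next_token == eot_token
    }
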
diff --git a/candle-transformers/src/models/gemma3.rs b/candle-transformers/src/models/gemma3.rs
index 7d5e520b..08b4e5ad 100644
--- a/candle-transformers/src/models/gemma3.rs
+++ b/candle-transformers/src/models/gemma3.rs
@@ -21,6 +21,7 @@ pub struct Config {
     pub num_key_value_heads: usize,
     pub rms_norm_eps: f64,
     pub rope_theta: f64,
+    pub rope_local_base_freq: f64,
     pub vocab_size: usize,
     pub final_logit_softcapping: Option<f64>,
     pub attn_logit_softcapping: Option<f64>,
@@ -67,12 +68,22 @@ struct RotaryEmbedding {
 }
 
 impl RotaryEmbedding {
-    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+    fn new(
+        dtype: DType,
+        cfg: &Config,
+        dev: &Device,
+        sliding_window: Option<usize>,
+    ) -> Result<Self> {
         let dim = cfg.head_dim;
         let max_seq_len = cfg.max_position_embeddings;
+        let rope_freq = if sliding_window.is_some() {
+            cfg.rope_local_base_freq
+        } else {
+            cfg.rope_theta
+        };
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
-            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+            .map(|i| 1f32 / rope_freq.powf(i as f64 / dim as f64) as f32)
             .collect();
         let inv_freq_len = inv_freq.len();
         let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
@@ -162,8 +173,8 @@ impl Attention {
     fn new(
         rotary_emb: Arc<RotaryEmbedding>,
         use_flash_attn: bool,
-        is_sliding: bool,
         cfg: &Config,
+        sliding_window: Option<usize>,
         vb: VarBuilder,
     ) -> Result<Self> {
         let hidden_sz = cfg.hidden_size;
@@ -178,13 +189,13 @@ impl Attention {
         let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
         let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
         let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
-        let kv_cache = if is_sliding {
-            KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
-                2,
-                cfg.sliding_window,
-            ))
+        let kv_cache = if let Some(sliding_window) = sliding_window {
+            KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(2, sliding_window))
         } else {
-            KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
+            KvCache::Normal(candle_nn::kv_cache::KvCache::new(
+                2,
+                cfg.max_position_embeddings,
+            ))
         };
         Ok(Self {
             q_proj,
@@ -302,21 +313,27 @@ struct DecoderLayer {
     pre_feedforward_layernorm: RmsNorm,
     post_feedforward_layernorm: RmsNorm,
     post_attention_layernorm: RmsNorm,
+    sliding_window: Option<usize>,
 }
 
 impl DecoderLayer {
     fn new(
-        rotary_emb: Arc<RotaryEmbedding>,
         use_flash_attn: bool,
-        is_sliding: bool,
         cfg: &Config,
         vb: VarBuilder,
+        sliding_window: Option<usize>,
     ) -> Result<Self> {
+        let rotary_emb = Arc::new(RotaryEmbedding::new(
+            vb.dtype(),
+            cfg,
+            vb.device(),
+            sliding_window,
+        )?);
         let self_attn = Attention::new(
             rotary_emb,
             use_flash_attn,
-            is_sliding,
             cfg,
+            sliding_window,
             vb.pp("self_attn"),
         )?;
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
@@ -344,6 +361,7 @@ impl DecoderLayer {
             pre_feedforward_layernorm,
             post_feedforward_layernorm,
             post_attention_layernorm,
+            sliding_window,
         })
     }
 
@@ -370,6 +388,42 @@ impl DecoderLayer {
     }
 }
 
+fn prepare_decoder_attention_mask(
+    b_size: usize,
+    tgt_len: usize,
+    seqlen_offset: usize,
+    sliding_window: Option<usize>,
+    dtype: DType,
+    device: &Device,
+) -> Result<Tensor> {
+    let mask: Vec<_> = if let Some(sliding_window) = sliding_window {
+        (0..tgt_len)
+            .flat_map(|i| {
+                (0..tgt_len).map(move |j| {
+                    if i < j || j + sliding_window < i {
+                        f32::NEG_INFINITY
+                    } else {
+                        0.
+                    }
+                })
+            })
+            .collect()
+    } else {
+        (0..tgt_len)
+            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0f32 }))
+            .collect()
+    };
+    let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
+    let mask = if seqlen_offset > 0 {
+        let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
+        Tensor::cat(&[&mask0, &mask], D::Minus1)?
+    } else {
+        mask
+    };
+    mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+        .to_dtype(dtype)
+}
+
 #[derive(Debug, Clone)]
 pub struct Model {
     embed_tokens: candle_nn::Embedding,
@@ -388,17 +442,15 @@ impl Model {
         let vb_m = vb.pp("model");
         let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
-        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         let vb_l = vb_m.pp("layers");
         for layer_idx in 0..cfg.num_hidden_layers {
-            let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
+            let sliding_window = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
             let layer = DecoderLayer::new(
-                rotary_emb.clone(),
                 use_flash_attn,
-                is_sliding,
                 cfg,
                 vb_l.pp(layer_idx),
+                sliding_window.then_some(cfg.sliding_window),
             )?;
             layers.push(layer)
         }
@@ -417,51 +469,52 @@ impl Model {
         })
     }
 
-    fn prepare_decoder_attention_mask(
+    fn create_attention_masks(
         &self,
-        b_size: usize,
-        tgt_len: usize,
+        batch_size: usize,
+        seq_len: usize,
         seqlen_offset: usize,
-    ) -> Result<Tensor> {
-        let mask: Vec<_> = match Some(self.sliding_window) {
-            None => (0..tgt_len)
-                .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
-                .collect(),
-            Some(sliding_window) => (0..tgt_len)
-                .flat_map(|i| {
-                    (0..tgt_len).map(move |j| {
-                        if i < j || j + sliding_window < i {
-                            f32::NEG_INFINITY
-                        } else {
-                            0.
-                        }
-                    })
-                })
-                .collect(),
-        };
-        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
-        let mask = if seqlen_offset > 0 {
-            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
-            Tensor::cat(&[&mask0, &mask], D::Minus1)?
-        } else {
-            mask
-        };
-        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
-            .to_dtype(self.dtype)
+    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        if seq_len <= 1 {
+            return Ok((None, None));
+        }
+
+        let mask = prepare_decoder_attention_mask(
+            batch_size,
+            seq_len,
+            seqlen_offset,
+            None,
+            self.dtype,
+            &self.device,
+        )?;
+
+        let sliding_mask = prepare_decoder_attention_mask(
+            batch_size,
+            seq_len,
+            seqlen_offset,
+            Some(self.sliding_window),
+            self.dtype,
+            &self.device,
+        )?;
+
+        Ok((Some(mask), Some(sliding_mask)))
     }
 
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
         let (b_size, seq_len) = input_ids.dims2()?;
-        let attention_mask = if seq_len <= 1 {
-            None
-        } else {
-            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
-            Some(mask)
-        };
         let xs = self.embed_tokens.forward(input_ids)?;
         let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
+
+        let (attention_mask, sliding_attention_mask) =
+            self.create_attention_masks(b_size, seq_len, seqlen_offset)?;
+
         for layer in self.layers.iter_mut() {
-            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
+            let mask = if layer.sliding_window.is_some() {
+                &sliding_attention_mask
+            } else {
+                &attention_mask
+            };
+            xs = layer.forward(&xs, mask.as_ref(), seqlen_offset)?
         }
         let logits = xs
            .narrow(1, seq_len - 1, 1)?