Fixed Gemma3 model and example (#2917)

* gemma3: changed the RotaryEmbedding base frequency per layer, based on whether the layer uses a sliding window (see the sketch after this list)

* Changed the attention mask per layer: either the full causal mask or the sliding-window mask

* Made attention mask creation slightly more efficient by building the masks only once per forward pass

* Changed the per-layer is_sliding flag to an Option<usize> sliding window

* clippy

* Changed the example to stop on both <eos> and <end_of_turn> instead of on only one of the two
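
For reference, the sketch below (not code from this commit) summarizes the per-layer selection these changes introduce: a layer whose sliding_window is Some(w) uses rope_local_base_freq and a rotating KV cache bounded by w, while a global layer uses rope_theta and a full-length cache. The constant values are assumptions for illustration only.

// Illustrative sketch, not part of the diff: how a layer's Option<usize> sliding
// window now drives both the RoPE base frequency and the KV-cache choice.

// Assumed values for illustration; the real ones come from Config.
const ROPE_THETA: f64 = 1_000_000.0;
const ROPE_LOCAL_BASE_FREQ: f64 = 10_000.0;
const MAX_POSITION_EMBEDDINGS: usize = 131_072;

#[derive(Debug)]
struct LayerSetup {
    rope_base: f64,
    cache_len: usize,
    rotating_cache: bool,
}

fn layer_setup(sliding_window: Option<usize>) -> LayerSetup {
    match sliding_window {
        // Local-attention layer: local RoPE base, rotating cache sized to the window.
        Some(w) => LayerSetup {
            rope_base: ROPE_LOCAL_BASE_FREQ,
            cache_len: w,
            rotating_cache: true,
        },
        // Global-attention layer: standard rope_theta, cache up to max_position_embeddings.
        None => LayerSetup {
            rope_base: ROPE_THETA,
            cache_len: MAX_POSITION_EMBEDDINGS,
            rotating_cache: false,
        },
    }
}

fn main() {
    println!("{:?}", layer_setup(Some(1024)));
    println!("{:?}", layer_setup(None));
}

The diff below applies the same Option<usize> to choose between the full causal mask and the sliding-window mask in the forward pass.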
Author: Kyle Birnbaum
Date: 2025-04-24 20:35:08 -07:00
Committed by: GitHub
Parent: 82def7ae38
Commit: 6ff0a6999c
2 changed files with 143 additions and 54 deletions


@@ -21,6 +21,7 @@ pub struct Config {
pub num_key_value_heads: usize,
pub rms_norm_eps: f64,
pub rope_theta: f64,
pub rope_local_base_freq: f64,
pub vocab_size: usize,
pub final_logit_softcapping: Option<f64>,
pub attn_logit_softcapping: Option<f64>,
@@ -67,12 +68,22 @@ struct RotaryEmbedding {
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
fn new(
dtype: DType,
cfg: &Config,
dev: &Device,
sliding_window: Option<usize>,
) -> Result<Self> {
let dim = cfg.head_dim;
let max_seq_len = cfg.max_position_embeddings;
let rope_freq = if sliding_window.is_some() {
cfg.rope_local_base_freq
} else {
cfg.rope_theta
};
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.map(|i| 1f32 / rope_freq.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
@@ -162,8 +173,8 @@ impl Attention {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
is_sliding: bool,
cfg: &Config,
sliding_window: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
@@ -178,13 +189,13 @@ impl Attention {
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
let kv_cache = if is_sliding {
KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
2,
cfg.sliding_window,
))
let kv_cache = if let Some(sliding_window) = sliding_window {
KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(2, sliding_window))
} else {
KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
KvCache::Normal(candle_nn::kv_cache::KvCache::new(
2,
cfg.max_position_embeddings,
))
};
Ok(Self {
q_proj,
@@ -302,21 +313,27 @@ struct DecoderLayer {
pre_feedforward_layernorm: RmsNorm,
post_feedforward_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
sliding_window: Option<usize>,
}
impl DecoderLayer {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
is_sliding: bool,
cfg: &Config,
vb: VarBuilder,
sliding_window: Option<usize>,
) -> Result<Self> {
let rotary_emb = Arc::new(RotaryEmbedding::new(
vb.dtype(),
cfg,
vb.device(),
sliding_window,
)?);
let self_attn = Attention::new(
rotary_emb,
use_flash_attn,
is_sliding,
cfg,
sliding_window,
vb.pp("self_attn"),
)?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
@@ -344,6 +361,7 @@ impl DecoderLayer {
pre_feedforward_layernorm,
post_feedforward_layernorm,
post_attention_layernorm,
sliding_window,
})
}
@@ -370,6 +388,42 @@ impl DecoderLayer {
}
}
fn prepare_decoder_attention_mask(
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
sliding_window: Option<usize>,
dtype: DType,
device: &Device,
) -> Result<Tensor> {
let mask: Vec<_> = if let Some(sliding_window) = sliding_window {
(0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect()
} else {
(0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0f32 }))
.collect()
};
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(dtype)
}
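
As an aside (not part of the diff), the hypothetical helper below mirrors the masking condition of the function above without candle Tensors, so the difference between the full causal mask and the sliding-window mask can be printed and inspected directly.

// Hypothetical, dependency-free helper: same masking rule as prepare_decoder_attention_mask,
// returned as a flat Vec<f32> instead of a Tensor so it is easy to print.
fn causal_mask(tgt_len: usize, sliding_window: Option<usize>) -> Vec<f32> {
    (0..tgt_len)
        .flat_map(|i| {
            (0..tgt_len).map(move |j| {
                // Position i may attend to j when j <= i and, on sliding layers,
                // only when j lies within the last `sliding_window` positions before i.
                let outside_window = sliding_window.map_or(false, |w| j + w < i);
                if i < j || outside_window {
                    f32::NEG_INFINITY
                } else {
                    0.0
                }
            })
        })
        .collect()
}

fn main() {
    // With tgt_len = 4: the full mask only hides future positions, while the
    // sliding mask with a window of 2 also hides positions more than 2 steps back.
    for (name, mask) in [
        ("full", causal_mask(4, None)),
        ("sliding(2)", causal_mask(4, Some(2))),
    ] {
        println!("{name}:");
        for row in mask.chunks(4) {
            println!("{row:?}");
        }
    }
}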
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
@@ -388,17 +442,15 @@ impl Model {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
let sliding_window = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
let layer = DecoderLayer::new(
rotary_emb.clone(),
use_flash_attn,
is_sliding,
cfg,
vb_l.pp(layer_idx),
sliding_window.then_some(cfg.sliding_window),
)?;
layers.push(layer)
}
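
For context (again not part of the diff), the small sketch below reproduces the (layer_idx + 1) % sliding_window_pattern test used above, making it easy to see which layers end up with Some(sliding_window) and which attend globally. The concrete pattern and window values are assumptions for illustration.

// Illustrative only: every sliding_window_pattern-th layer gets None (global attention),
// all other layers get Some(window) and therefore a rotating KV cache and local RoPE base.
fn layer_window(layer_idx: usize, pattern: usize, window: usize) -> Option<usize> {
    ((layer_idx + 1) % pattern > 0).then_some(window)
}

fn main() {
    let pattern = 6; // assumed cfg.sliding_window_pattern
    let window = 1024; // assumed cfg.sliding_window
    for layer_idx in 0..12 {
        // Prints Some(1024) for layers 0..=4 and 6..=10, None for layers 5 and 11.
        println!("layer {layer_idx}: {:?}", layer_window(layer_idx, pattern, window));
    }
}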
@@ -417,51 +469,52 @@ impl Model {
})
}
fn prepare_decoder_attention_mask(
fn create_attention_masks(
&self,
b_size: usize,
tgt_len: usize,
batch_size: usize,
seq_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let mask: Vec<_> = match Some(self.sliding_window) {
None => (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect(),
Some(sliding_window) => (0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect(),
};
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
) -> Result<(Option<Tensor>, Option<Tensor>)> {
if seq_len <= 1 {
return Ok((None, None));
}
let mask = prepare_decoder_attention_mask(
batch_size,
seq_len,
seqlen_offset,
None,
self.dtype,
&self.device,
)?;
let sliding_mask = prepare_decoder_attention_mask(
batch_size,
seq_len,
seqlen_offset,
Some(self.sliding_window),
self.dtype,
&self.device,
)?;
Ok((Some(mask), Some(sliding_mask)))
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let xs = self.embed_tokens.forward(input_ids)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
let (attention_mask, sliding_attention_mask) =
self.create_attention_masks(b_size, seq_len, seqlen_offset)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
let mask = if layer.sliding_window.is_some() {
&sliding_attention_mask
} else {
&attention_mask
};
xs = layer.forward(&xs, mask.as_ref(), seqlen_offset)?
}
let logits = xs
.narrow(1, seq_len - 1, 1)?