Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Fixed Gemma3 model and example (#2917)
* gemma3: changed RotaryEmbedding base freq based on layer and sliding window
* changed attention mask per layer, either normal or sliding
* made attention mask creation slightly more efficient by only creating them once per model iteration
* changed is_sliding to an Option
* clippy
* changed to stop on both <eos> and <end_of_turn> instead of just one of the two
@@ -21,6 +21,7 @@ pub struct Config {
     pub num_key_value_heads: usize,
     pub rms_norm_eps: f64,
     pub rope_theta: f64,
+    pub rope_local_base_freq: f64,
     pub vocab_size: usize,
     pub final_logit_softcapping: Option<f64>,
     pub attn_logit_softcapping: Option<f64>,
@@ -67,12 +68,22 @@ struct RotaryEmbedding {
 }
 
 impl RotaryEmbedding {
-    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+    fn new(
+        dtype: DType,
+        cfg: &Config,
+        dev: &Device,
+        sliding_window: Option<usize>,
+    ) -> Result<Self> {
         let dim = cfg.head_dim;
         let max_seq_len = cfg.max_position_embeddings;
+        let rope_freq = if sliding_window.is_some() {
+            cfg.rope_local_base_freq
+        } else {
+            cfg.rope_theta
+        };
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
-            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+            .map(|i| 1f32 / rope_freq.powf(i as f64 / dim as f64) as f32)
             .collect();
         let inv_freq_len = inv_freq.len();
         let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
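A note on what this change does: the only thing that differs between a sliding-window layer's rotary embedding and a global layer's is the base used for the inverse-frequency table. The standalone sketch below reproduces the powf formula from the hunk above; the head dimension of 256 and the base values 10_000 (local) and 1_000_000 (global) are illustrative assumptions, not values taken from this commit.

// Standalone illustration of the inv_freq computation in RotaryEmbedding::new.
// The dimension and base frequencies are assumed example values, not from the diff.
fn inv_freqs(base: f64, dim: usize) -> Vec<f32> {
    (0..dim)
        .step_by(2)
        .map(|i| 1f32 / base.powf(i as f64 / dim as f64) as f32)
        .collect()
}

fn main() {
    let dim = 256;
    // Sliding-window layers would use cfg.rope_local_base_freq (assumed 10_000 here),
    // full-attention layers cfg.rope_theta (assumed 1_000_000 here).
    let local = inv_freqs(10_000.0, dim);
    let global = inv_freqs(1_000_000.0, dim);
    // The smaller local base keeps rotation frequencies higher, which suits
    // short-range attention; the larger global base gives longer wavelengths.
    println!("local last freq:  {:e}", local.last().unwrap());
    println!("global last freq: {:e}", global.last().unwrap());
}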
@@ -162,8 +173,8 @@ impl Attention {
     fn new(
         rotary_emb: Arc<RotaryEmbedding>,
         use_flash_attn: bool,
-        is_sliding: bool,
         cfg: &Config,
+        sliding_window: Option<usize>,
         vb: VarBuilder,
     ) -> Result<Self> {
         let hidden_sz = cfg.hidden_size;
@@ -178,13 +189,13 @@ impl Attention {
         let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
         let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
         let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
-        let kv_cache = if is_sliding {
-            KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
-                2,
-                cfg.sliding_window,
-            ))
+        let kv_cache = if let Some(sliding_window) = sliding_window {
+            KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(2, sliding_window))
         } else {
-            KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
+            KvCache::Normal(candle_nn::kv_cache::KvCache::new(
+                2,
+                cfg.max_position_embeddings,
+            ))
         };
         Ok(Self {
             q_proj,
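The cache now follows the attention type: a sliding layer only ever needs the most recent sliding_window key/value pairs, so it gets a rotating cache with that capacity, while a global layer keeps an ordinary cache sized to max_position_embeddings. A minimal sketch of that decision, using the same candle_nn::kv_cache constructors as the hunk above; CacheKind is a hypothetical stand-in for the model's private cache enum:

use candle_nn::kv_cache::{KvCache, RotatingKvCache};

// Hypothetical local stand-in for the model's private cache enum.
enum CacheKind {
    Normal(KvCache),
    Rotating(RotatingKvCache),
}

// Sliding layers cap the cache at the window size; global layers use the full
// context length, matching the branch in Attention::new above.
fn make_cache(sliding_window: Option<usize>, max_position_embeddings: usize) -> CacheKind {
    match sliding_window {
        Some(w) => CacheKind::Rotating(RotatingKvCache::new(2, w)),
        None => CacheKind::Normal(KvCache::new(2, max_position_embeddings)),
    }
}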
@@ -302,21 +313,27 @@ struct DecoderLayer {
     pre_feedforward_layernorm: RmsNorm,
     post_feedforward_layernorm: RmsNorm,
     post_attention_layernorm: RmsNorm,
+    sliding_window: Option<usize>,
 }
 
 impl DecoderLayer {
     fn new(
-        rotary_emb: Arc<RotaryEmbedding>,
         use_flash_attn: bool,
-        is_sliding: bool,
         cfg: &Config,
         vb: VarBuilder,
+        sliding_window: Option<usize>,
     ) -> Result<Self> {
+        let rotary_emb = Arc::new(RotaryEmbedding::new(
+            vb.dtype(),
+            cfg,
+            vb.device(),
+            sliding_window,
+        )?);
         let self_attn = Attention::new(
             rotary_emb,
             use_flash_attn,
-            is_sliding,
             cfg,
+            sliding_window,
             vb.pp("self_attn"),
         )?;
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
@@ -344,6 +361,7 @@ impl DecoderLayer {
             pre_feedforward_layernorm,
             post_feedforward_layernorm,
             post_attention_layernorm,
+            sliding_window,
         })
     }
 
@@ -370,6 +388,42 @@ impl DecoderLayer {
     }
 }
 
+fn prepare_decoder_attention_mask(
+    b_size: usize,
+    tgt_len: usize,
+    seqlen_offset: usize,
+    sliding_window: Option<usize>,
+    dtype: DType,
+    device: &Device,
+) -> Result<Tensor> {
+    let mask: Vec<_> = if let Some(sliding_window) = sliding_window {
+        (0..tgt_len)
+            .flat_map(|i| {
+                (0..tgt_len).map(move |j| {
+                    if i < j || j + sliding_window < i {
+                        f32::NEG_INFINITY
+                    } else {
+                        0.
+                    }
+                })
+            })
+            .collect()
+    } else {
+        (0..tgt_len)
+            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0f32 }))
+            .collect()
+    };
+    let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
+    let mask = if seqlen_offset > 0 {
+        let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
+        Tensor::cat(&[&mask0, &mask], D::Minus1)?
+    } else {
+        mask
+    };
+    mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+        .to_dtype(dtype)
+}
+
 #[derive(Debug, Clone)]
 pub struct Model {
     embed_tokens: candle_nn::Embedding,
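To make the sliding predicate concrete: position i may attend to position j only when j <= i (causal) and i - j <= sliding_window. The standalone check below runs the same closure as prepare_decoder_attention_mask for an assumed tgt_len of 5 and window of 2:

// Standalone check of the mask predicate for tgt_len = 5, sliding_window = 2
// (both values chosen purely for illustration).
fn main() {
    let tgt_len = 5usize;
    let sliding_window = 2usize;
    for i in 0..tgt_len {
        let row: Vec<&str> = (0..tgt_len)
            .map(|j| {
                if i < j || j + sliding_window < i {
                    "-inf" // masked out
                } else {
                    "0" // attended
                }
            })
            .collect();
        println!("{row:?}");
    }
    // Row 4 prints ["-inf", "-inf", "0", "0", "0"]: the last position sees only
    // itself and the two positions before it.
}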
@@ -388,17 +442,15 @@ impl Model {
         let vb_m = vb.pp("model");
         let embed_tokens =
             candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
-        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         let vb_l = vb_m.pp("layers");
         for layer_idx in 0..cfg.num_hidden_layers {
-            let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
+            let sliding_window = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
             let layer = DecoderLayer::new(
-                rotary_emb.clone(),
                 use_flash_attn,
-                is_sliding,
                 cfg,
                 vb_l.pp(layer_idx),
+                sliding_window.then_some(cfg.sliding_window),
             )?;
             layers.push(layer)
         }
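The then_some call is what replaces the old is_sliding flag with an Option: every layer whose 1-based index is not a multiple of sliding_window_pattern gets Some(cfg.sliding_window), and the remaining layers get None and attend globally. A standalone sketch of the assignment, with sliding_window_pattern = 6 and sliding_window = 4096 as assumed example values:

// Per-layer sliding-window assignment, mirroring the loop in Model::new.
// sliding_window_pattern = 6 and sliding_window = 4096 are assumed example values.
fn main() {
    let sliding_window_pattern = 6usize;
    let sliding_window = 4096usize;
    let num_hidden_layers = 12usize;
    for layer_idx in 0..num_hidden_layers {
        let is_sliding = (layer_idx + 1) % sliding_window_pattern > 0;
        // None => global attention layer, Some(w) => sliding-window layer.
        let per_layer_window = is_sliding.then_some(sliding_window);
        println!("layer {layer_idx}: {per_layer_window:?}");
    }
}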
@@ -417,51 +469,52 @@ impl Model {
         })
     }
 
-    fn prepare_decoder_attention_mask(
+    fn create_attention_masks(
         &self,
-        b_size: usize,
-        tgt_len: usize,
+        batch_size: usize,
+        seq_len: usize,
         seqlen_offset: usize,
-    ) -> Result<Tensor> {
-        let mask: Vec<_> = match Some(self.sliding_window) {
-            None => (0..tgt_len)
-                .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
-                .collect(),
-            Some(sliding_window) => (0..tgt_len)
-                .flat_map(|i| {
-                    (0..tgt_len).map(move |j| {
-                        if i < j || j + sliding_window < i {
-                            f32::NEG_INFINITY
-                        } else {
-                            0.
-                        }
-                    })
-                })
-                .collect(),
-        };
-        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
-        let mask = if seqlen_offset > 0 {
-            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
-            Tensor::cat(&[&mask0, &mask], D::Minus1)?
-        } else {
-            mask
-        };
-        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
-            .to_dtype(self.dtype)
+    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        if seq_len <= 1 {
+            return Ok((None, None));
+        }
+
+        let mask = prepare_decoder_attention_mask(
+            batch_size,
+            seq_len,
+            seqlen_offset,
+            None,
+            self.dtype,
+            &self.device,
+        )?;
+
+        let sliding_mask = prepare_decoder_attention_mask(
+            batch_size,
+            seq_len,
+            seqlen_offset,
+            Some(self.sliding_window),
+            self.dtype,
+            &self.device,
+        )?;
+
+        Ok((Some(mask), Some(sliding_mask)))
     }
 
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
         let (b_size, seq_len) = input_ids.dims2()?;
-        let attention_mask = if seq_len <= 1 {
-            None
-        } else {
-            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
-            Some(mask)
-        };
         let xs = self.embed_tokens.forward(input_ids)?;
         let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
+
+        let (attention_mask, sliding_attention_mask) =
+            self.create_attention_masks(b_size, seq_len, seqlen_offset)?;
+
         for layer in self.layers.iter_mut() {
-            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
+            let mask = if layer.sliding_window.is_some() {
+                &sliding_attention_mask
+            } else {
+                &attention_mask
+            };
+            xs = layer.forward(&xs, mask.as_ref(), seqlen_offset)?
        }
         let logits = xs
             .narrow(1, seq_len - 1, 1)?
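The last bullet of the commit message concerns the example binary rather than the model file shown above: generation now stops on either <eos> or <end_of_turn>. A rough sketch of that check, assuming a tokenizers::Tokenizer is in scope and that both special tokens exist under exactly these names in the Gemma 3 vocabulary:

use tokenizers::Tokenizer;

// True when next_token is one of the two stop tokens the example now checks for.
// The token strings are the ones named in the commit message; ids are looked up at runtime.
fn is_stop_token(tokenizer: &Tokenizer, next_token: u32) -> bool {
    let eos = tokenizer.token_to_id("<eos>");
    let end_of_turn = tokenizer.token_to_id("<end_of_turn>");
    Some(next_token) == eos || Some(next_token) == end_of_turn
}

The sampling loop would break as soon as this returns true, rather than comparing against a single stop id.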