Rope fix.

laurent
2025-04-03 18:01:25 +02:00
parent e319cd78d9
commit 825119ac4b

@@ -39,7 +39,7 @@ pub struct LlamaConfig {
     max_seq_len: usize,
     intermediate_dim: usize,
     norm_eps: f64,
-    rope_base: f64,
+    rope_base: f32,
     scale_factor: usize,
 }
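
The `rope_base: f64 -> f32` change lines the config up with the f32 math that consumes it: the `calculate_default_inv_freq` helper added in the next hunk divides by `rope_base.powf(..)` directly in f32, where the old code computed the power in f64 and then cast. A standalone sketch of that computation, with hypothetical values in place of the real config:

fn main() {
    let rope_base: f32 = 10_000.0; // hypothetical; the real value comes from LlamaConfig
    let head_dim = 64usize;        // hypothetical embed_dim / num_heads
    // Same formula as calculate_default_inv_freq below: 1 / base^(2i / d).
    let inv_freq: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / rope_base.powf(i as f32 / head_dim as f32))
        .collect();
    assert_eq!(inv_freq.len(), head_dim / 2);
}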
@@ -80,24 +80,52 @@ struct RotaryEmbedding {
     cos: Tensor,
 }
 
+fn calculate_default_inv_freq(cfg: &LlamaConfig) -> Vec<f32> {
+    let head_dim = cfg.embed_dim / cfg.num_heads;
+    (0..head_dim)
+        .step_by(2)
+        .map(|i| 1f32 / cfg.rope_base.powf(i as f32 / head_dim as f32))
+        .collect()
+}
+
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &LlamaConfig, dev: &Device) -> Result<Self> {
-        let dim = cfg.embed_dim / cfg.num_heads;
-        let max_seq_len = cfg.max_seq_len;
-        let inv_freq: Vec<_> = (0..dim)
-            .step_by(2)
-            .map(|i| 1f32 / cfg.rope_base.powf(i as f64 / dim as f64) as f32)
-            .collect();
-        let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
-        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-            .to_dtype(dtype)?
-            .reshape((max_seq_len, 1))?;
-        let freqs = t.matmul(&inv_freq)?;
-        Ok(Self {
-            sin: freqs.sin()?,
-            cos: freqs.cos()?,
-        })
+        let low_freq_factor = 1.0;
+        let high_freq_factor = 4.0;
+        let original_max_position_embeddings = 8192;
+        let scale_factor = cfg.scale_factor as f32;
+        let theta = {
+            let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
+            let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
+            calculate_default_inv_freq(cfg)
+                .into_iter()
+                .map(|freq| {
+                    let wavelen = 2. * std::f32::consts::PI / freq;
+                    if wavelen < high_freq_wavelen {
+                        freq
+                    } else if wavelen > low_freq_wavelen {
+                        freq / scale_factor
+                    } else {
+                        let smooth = (original_max_position_embeddings as f32 / wavelen
+                            - low_freq_factor)
+                            / (high_freq_factor - low_freq_factor);
+                        (1. - smooth) * freq / scale_factor + smooth * freq
+                    }
+                })
+                .collect::<Vec<_>>()
+        };
+        let theta = Tensor::new(theta, dev)?;
+        let idx_theta = Tensor::arange(0, cfg.max_seq_len as u32, dev)?
+            .to_dtype(DType::F32)?
+            .reshape((cfg.max_seq_len, 1))?
+            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+        // This is different from the paper, see:
+        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
+        let cos = idx_theta.cos()?.to_dtype(dtype)?;
+        let sin = idx_theta.sin()?.to_dtype(dtype)?;
+        Ok(Self { cos, sin })
     }
 
     fn apply_rotary_emb_qkv(
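
The new construction happens in two steps. `calculate_default_inv_freq` builds the standard RoPE inverse frequencies, 1 / rope_base^(2i / head_dim). The body of `new` then applies Llama-3-style long-context rescaling: components whose wavelength is short relative to the original 8192-token window keep their frequency, components whose wavelength exceeds that window are slowed down by `scale_factor`, and the band in between is blended linearly. A self-contained sketch of that rule, using the constants from the diff (function name and test values are mine):

fn scale_inv_freq(freq: f32, scale_factor: f32) -> f32 {
    let (low_freq_factor, high_freq_factor) = (1.0f32, 4.0f32);
    let original_max_position_embeddings = 8192.0f32;
    let low_freq_wavelen = original_max_position_embeddings / low_freq_factor;
    let high_freq_wavelen = original_max_position_embeddings / high_freq_factor;
    let wavelen = 2.0 * std::f32::consts::PI / freq;
    if wavelen < high_freq_wavelen {
        freq // fast component: unchanged
    } else if wavelen > low_freq_wavelen {
        freq / scale_factor // slow component: stretched to the longer context
    } else {
        // linear blend between the two regimes
        let smooth = (original_max_position_embeddings / wavelen - low_freq_factor)
            / (high_freq_factor - low_freq_factor);
        (1.0 - smooth) * freq / scale_factor + smooth * freq
    }
}

fn main() {
    // wavelen(1.0) ~= 6.28 < 2048, so the frequency passes through untouched,
    assert_eq!(scale_inv_freq(1.0, 8.0), 1.0);
    // while wavelen(1e-4) ~= 62832 > 8192, so it is divided by the full factor.
    assert_eq!(scale_inv_freq(1e-4, 8.0), 1e-4 / 8.0);
}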
@@ -109,8 +137,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
-        let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
+        let q_embed = candle_nn::rotary_emb::rope_i(q, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope_i(k, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }
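
The other half of the fix is `rope` -> `rope_i`. Both helpers in `candle_nn::rotary_emb` take `cos`/`sin` of shape `(seq_len, head_dim / 2)`; they differ in which features form a rotation pair. As I understand the candle API, `rope` rotates `(x[i], x[i + d/2])` (the split-half layout) while `rope_i` rotates adjacent features `(x[2i], x[2i + 1])` (interleaved), which is the layout this model's weights expect. A rough reference of the interleaved variant in plain tensor ops, for intuition only (not candle's optimized kernel; imports assume the candle core crate this file already uses):

use candle::{Result, Tensor};

// x: (b, h, t, d) contiguous, cos/sin: (t, d / 2), as in the call above.
fn rope_i_reference(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
    let (b, h, t, d) = x.dims4()?;
    let x = x.reshape((b, h, t, d / 2, 2))?;
    let x0 = x.narrow(4, 0, 1)?; // features 0, 2, 4, ...
    let x1 = x.narrow(4, 1, 1)?; // features 1, 3, 5, ...
    let cos = cos.reshape((t, d / 2, 1))?;
    let sin = sin.reshape((t, d / 2, 1))?;
    // 2-D rotation of each (even, odd) pair by its per-dimension angle.
    let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
    let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
    Tensor::cat(&[&y0, &y1], 4)?.reshape((b, h, t, d))
}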
@@ -347,7 +375,7 @@ impl LlamaModel {
         };
         let mut xs = xs.clone();
         for layer in self.layers.iter_mut() {
-            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
+            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?;
         }
         let ys = xs.narrow(1, seq_len - 1, 1)?.apply(&self.norm)?;
         Ok(ys)
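
The last hunk only adds the statement-terminating semicolon in the decoding loop; the assignment already evaluated to `()`, so this is a style fix rather than a behavioural one. The context lines after it show that the model returns a single position: `narrow(1, seq_len - 1, 1)` keeps the final time step before the output norm. A quick shape check with illustrative sizes:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let (b, t, d) = (2, 5, 8); // illustrative batch, time, embed sizes
    let xs = Tensor::zeros((b, t, d), DType::F32, &Device::Cpu)?;
    let ys = xs.narrow(1, t - 1, 1)?; // keep only the last position
    assert_eq!(ys.dims(), &[b, 1, d]);
    Ok(())
}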