From 825119ac4b53c66545c27c33ebc002f7ceb939c0 Mon Sep 17 00:00:00 2001
From: laurent
Date: Thu, 3 Apr 2025 18:01:25 +0200
Subject: [PATCH] Rope fix.

---
 candle-transformers/src/models/csm.rs | 68 +++++++++++++++++++--------
 1 file changed, 48 insertions(+), 20 deletions(-)

diff --git a/candle-transformers/src/models/csm.rs b/candle-transformers/src/models/csm.rs
index b8d3c2f1..02e95f99 100644
--- a/candle-transformers/src/models/csm.rs
+++ b/candle-transformers/src/models/csm.rs
@@ -39,7 +39,7 @@ pub struct LlamaConfig {
     max_seq_len: usize,
     intermediate_dim: usize,
     norm_eps: f64,
-    rope_base: f64,
+    rope_base: f32,
     scale_factor: usize,
 }
 
@@ -80,24 +80,52 @@ struct RotaryEmbedding {
     cos: Tensor,
 }
 
+fn calculate_default_inv_freq(cfg: &LlamaConfig) -> Vec<f32> {
+    let head_dim = cfg.embed_dim / cfg.num_heads;
+    (0..head_dim)
+        .step_by(2)
+        .map(|i| 1f32 / cfg.rope_base.powf(i as f32 / head_dim as f32))
+        .collect()
+}
+
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &LlamaConfig, dev: &Device) -> Result<Self> {
-        let dim = cfg.embed_dim / cfg.num_heads;
-        let max_seq_len = cfg.max_seq_len;
-        let inv_freq: Vec<_> = (0..dim)
-            .step_by(2)
-            .map(|i| 1f32 / cfg.rope_base.powf(i as f64 / dim as f64) as f32)
-            .collect();
-        let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
-        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-            .to_dtype(dtype)?
-            .reshape((max_seq_len, 1))?;
-        let freqs = t.matmul(&inv_freq)?;
-        Ok(Self {
-            sin: freqs.sin()?,
-            cos: freqs.cos()?,
-        })
+        let low_freq_factor = 1.0;
+        let high_freq_factor = 4.0;
+        let original_max_position_embeddings = 8192;
+        let scale_factor = cfg.scale_factor as f32;
+        let theta = {
+            let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
+            let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
+
+            calculate_default_inv_freq(cfg)
+                .into_iter()
+                .map(|freq| {
+                    let wavelen = 2. * std::f32::consts::PI / freq;
+                    if wavelen < high_freq_wavelen {
+                        freq
+                    } else if wavelen > low_freq_wavelen {
+                        freq / scale_factor
+                    } else {
+                        let smooth = (original_max_position_embeddings as f32 / wavelen
+                            - low_freq_factor)
+                            / (high_freq_factor - low_freq_factor);
+                        (1. - smooth) * freq / scale_factor + smooth * freq
+                    }
+                })
+                .collect::<Vec<_>>()
+        };
+
+        let theta = Tensor::new(theta, dev)?;
+        let idx_theta = Tensor::arange(0, cfg.max_seq_len as u32, dev)?
+            .to_dtype(DType::F32)?
+            .reshape((cfg.max_seq_len, 1))?
+            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+        // This is different from the paper, see:
+        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
+        let cos = idx_theta.cos()?.to_dtype(dtype)?;
+        let sin = idx_theta.sin()?.to_dtype(dtype)?;
+        Ok(Self { cos, sin })
     }
 
     fn apply_rotary_emb_qkv(
@@ -109,8 +137,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
-        let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
+        let q_embed = candle_nn::rotary_emb::rope_i(q, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope_i(k, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }
@@ -347,7 +375,7 @@ impl LlamaModel {
         };
         let mut xs = xs.clone();
        for layer in self.layers.iter_mut() {
-            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
+            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?;
         }
         let ys = xs.narrow(1, seq_len - 1, 1)?.apply(&self.norm)?;
         Ok(ys)
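Note (not part of the patch): a minimal standalone sketch of the Llama-3-style frequency scaling that the reworked RotaryEmbedding::new applies, in plain Rust with no candle dependency. The head_dim, rope_base and scale_factor values below are illustrative assumptions, not values read from LlamaConfig.

// Sketch of the piecewise RoPE frequency scaling used in the patch above.
// head_dim, rope_base and scale_factor are assumed values for illustration.
fn main() {
    let head_dim = 64usize;
    let rope_base = 500_000f32;
    let scale_factor = 32f32;
    let low_freq_factor = 1.0f32;
    let high_freq_factor = 4.0f32;
    let original_max_position_embeddings = 8192f32;

    let low_freq_wavelen = original_max_position_embeddings / low_freq_factor;
    let high_freq_wavelen = original_max_position_embeddings / high_freq_factor;

    // Default inverse frequencies, one per pair of head dimensions.
    let inv_freq: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / rope_base.powf(i as f32 / head_dim as f32))
        .collect();

    // High-frequency components are kept as-is, low-frequency components are
    // divided by scale_factor, and the band in between is interpolated.
    let scaled: Vec<f32> = inv_freq
        .iter()
        .map(|&freq| {
            let wavelen = 2. * std::f32::consts::PI / freq;
            if wavelen < high_freq_wavelen {
                freq
            } else if wavelen > low_freq_wavelen {
                freq / scale_factor
            } else {
                let smooth = (original_max_position_embeddings / wavelen - low_freq_factor)
                    / (high_freq_factor - low_freq_factor);
                (1. - smooth) * freq / scale_factor + smooth * freq
            }
        })
        .collect();

    println!("first scaled inv freqs: {:?}", &scaled[..4]);
}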