Rope fix.
@@ -39,7 +39,7 @@ pub struct LlamaConfig {
     max_seq_len: usize,
     intermediate_dim: usize,
     norm_eps: f64,
-    rope_base: f64,
+    rope_base: f32,
     scale_factor: usize,
 }
 
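Note: `rope_base` switches from `f64` to `f32`, which lets the inverse-frequency computation below stay in `f32` throughout instead of computing in `f64` and casting back down. A minimal standalone sketch of that default computation, using an illustrative head dimension and base (hypothetical example values, not taken from this diff):

    // Default RoPE inverse frequencies: 1 / base^(i / head_dim) for i = 0, 2, 4, ...
    fn default_inv_freq(head_dim: usize, base: f32) -> Vec<f32> {
        (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect()
    }

    fn main() {
        // head_dim = 64 and base = 10_000.0 are illustrative values.
        let inv_freq = default_inv_freq(64, 10_000.0);
        assert_eq!(inv_freq.len(), 32); // one frequency per rotated pair
        println!("{:?}", &inv_freq[..4]);
    }
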
@@ -80,24 +80,52 @@ struct RotaryEmbedding {
     cos: Tensor,
 }
 
+fn calculate_default_inv_freq(cfg: &LlamaConfig) -> Vec<f32> {
+    let head_dim = cfg.embed_dim / cfg.num_heads;
+    (0..head_dim)
+        .step_by(2)
+        .map(|i| 1f32 / cfg.rope_base.powf(i as f32 / head_dim as f32))
+        .collect()
+}
+
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &LlamaConfig, dev: &Device) -> Result<Self> {
-        let dim = cfg.embed_dim / cfg.num_heads;
-        let max_seq_len = cfg.max_seq_len;
-        let inv_freq: Vec<_> = (0..dim)
-            .step_by(2)
-            .map(|i| 1f32 / cfg.rope_base.powf(i as f64 / dim as f64) as f32)
-            .collect();
-        let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
-        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-            .to_dtype(dtype)?
-            .reshape((max_seq_len, 1))?;
-        let freqs = t.matmul(&inv_freq)?;
-        Ok(Self {
-            sin: freqs.sin()?,
-            cos: freqs.cos()?,
-        })
+        let low_freq_factor = 1.0;
+        let high_freq_factor = 4.0;
+        let original_max_position_embeddings = 8192;
+        let scale_factor = cfg.scale_factor as f32;
+        let theta = {
+            let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
+            let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
+
+            calculate_default_inv_freq(cfg)
+                .into_iter()
+                .map(|freq| {
+                    let wavelen = 2. * std::f32::consts::PI / freq;
+                    if wavelen < high_freq_wavelen {
+                        freq
+                    } else if wavelen > low_freq_wavelen {
+                        freq / scale_factor
+                    } else {
+                        let smooth = (original_max_position_embeddings as f32 / wavelen
+                            - low_freq_factor)
+                            / (high_freq_factor - low_freq_factor);
+                        (1. - smooth) * freq / scale_factor + smooth * freq
+                    }
+                })
+                .collect::<Vec<_>>()
+        };
+
+        let theta = Tensor::new(theta, dev)?;
+        let idx_theta = Tensor::arange(0, cfg.max_seq_len as u32, dev)?
+            .to_dtype(DType::F32)?
+            .reshape((cfg.max_seq_len, 1))?
+            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+        // This is different from the paper, see:
+        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
+        let cos = idx_theta.cos()?.to_dtype(dtype)?;
+        let sin = idx_theta.sin()?.to_dtype(dtype)?;
+        Ok(Self { cos, sin })
     }
 
     fn apply_rotary_emb_qkv(
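The rewritten constructor applies Llama 3.1-style RoPE scaling: frequencies whose wavelength is short relative to the original 8192-token context are left untouched, long wavelengths are divided by `scale_factor`, and the band in between is linearly interpolated. A scalar sketch of the per-frequency rule, lifted from the closure above into a standalone function (same constants as the diff; `scale_freq` is a hypothetical name):

    fn scale_freq(freq: f32, scale_factor: f32) -> f32 {
        // Constants from the diff: factors 1.0 / 4.0, original context 8192.
        let (low_freq_factor, high_freq_factor) = (1.0f32, 4.0f32);
        let original_max_position_embeddings = 8192f32;
        let low_freq_wavelen = original_max_position_embeddings / low_freq_factor;
        let high_freq_wavelen = original_max_position_embeddings / high_freq_factor;
        let wavelen = 2. * std::f32::consts::PI / freq;
        if wavelen < high_freq_wavelen {
            freq // high-frequency band: kept as-is
        } else if wavelen > low_freq_wavelen {
            freq / scale_factor // low-frequency band: fully rescaled
        } else {
            // smooth linear interpolation between the two regimes
            let smooth = (original_max_position_embeddings / wavelen - low_freq_factor)
                / (high_freq_factor - low_freq_factor);
            (1. - smooth) * freq / scale_factor + smooth * freq
        }
    }
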
@@ -109,8 +137,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
-        let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
+        let q_embed = candle_nn::rotary_emb::rope_i(q, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope_i(k, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }
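Switching from `candle_nn::rotary_emb::rope` to `rope_i` changes the rotation layout: `rope` rotates the two contiguous halves of each head dimension against each other, while `rope_i` rotates interleaved pairs `(x[2i], x[2i+1])`, presumably matching the convention the checkpoint was trained with. A minimal sketch of the interleaved rotation on a single head vector (illustrative only, not the candle kernel itself):

    // Interleaved RoPE: each adjacent pair is rotated by the angle whose
    // cos/sin were precomputed for this position and pair index.
    fn rope_i_1d(x: &mut [f32], cos: &[f32], sin: &[f32]) {
        for i in 0..x.len() / 2 {
            let (a, b) = (x[2 * i], x[2 * i + 1]);
            x[2 * i] = a * cos[i] - b * sin[i];
            x[2 * i + 1] = a * sin[i] + b * cos[i];
        }
    }
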
@@ -347,7 +375,7 @@ impl LlamaModel {
         };
         let mut xs = xs.clone();
         for layer in self.layers.iter_mut() {
-            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
+            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?;
        }
         let ys = xs.narrow(1, seq_len - 1, 1)?.apply(&self.norm)?;
         Ok(ys)