mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Rope fix.
This commit is contained in:
@ -39,7 +39,7 @@ pub struct LlamaConfig {
|
|||||||
max_seq_len: usize,
|
max_seq_len: usize,
|
||||||
intermediate_dim: usize,
|
intermediate_dim: usize,
|
||||||
norm_eps: f64,
|
norm_eps: f64,
|
||||||
rope_base: f64,
|
rope_base: f32,
|
||||||
scale_factor: usize,
|
scale_factor: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,24 +80,52 @@ struct RotaryEmbedding {
|
|||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Computes the unscaled rotary-embedding inverse frequencies for `cfg`.
///
/// `head_dim` is `cfg.embed_dim / cfg.num_heads`. For every even index `i`
/// in `0..head_dim` the returned vector holds
/// `1 / cfg.rope_base.powf(i / head_dim)`, evaluated in `f32`, in increasing
/// order of `i`.
///
/// NOTE(review): assumes `cfg.num_heads` is non-zero, since it is used as a
/// divisor here — TODO confirm this is guaranteed where the config is built.
fn calculate_default_inv_freq(cfg: &LlamaConfig) -> Vec<f32> {
|
||||||
|
let head_dim = cfg.embed_dim / cfg.num_heads;
|
||||||
|
(0..head_dim)
|
||||||
|
// Only even indices are visited, so one frequency is produced per pair of head dimensions.
.step_by(2)
|
||||||
|
.map(|i| 1f32 / cfg.rope_base.powf(i as f32 / head_dim as f32))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
impl RotaryEmbedding {
|
impl RotaryEmbedding {
|
||||||
fn new(dtype: DType, cfg: &LlamaConfig, dev: &Device) -> Result<Self> {
|
fn new(dtype: DType, cfg: &LlamaConfig, dev: &Device) -> Result<Self> {
|
||||||
let dim = cfg.embed_dim / cfg.num_heads;
|
let low_freq_factor = 1.0;
|
||||||
let max_seq_len = cfg.max_seq_len;
|
let high_freq_factor = 4.0;
|
||||||
let inv_freq: Vec<_> = (0..dim)
|
let original_max_position_embeddings = 8192;
|
||||||
.step_by(2)
|
let scale_factor = cfg.scale_factor as f32;
|
||||||
.map(|i| 1f32 / cfg.rope_base.powf(i as f64 / dim as f64) as f32)
|
let theta = {
|
||||||
.collect();
|
let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
|
||||||
let inv_freq_len = inv_freq.len();
|
let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
|
||||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
|
||||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
calculate_default_inv_freq(cfg)
|
||||||
.to_dtype(dtype)?
|
.into_iter()
|
||||||
.reshape((max_seq_len, 1))?;
|
.map(|freq| {
|
||||||
let freqs = t.matmul(&inv_freq)?;
|
let wavelen = 2. * std::f32::consts::PI / freq;
|
||||||
Ok(Self {
|
if wavelen < high_freq_wavelen {
|
||||||
sin: freqs.sin()?,
|
freq
|
||||||
cos: freqs.cos()?,
|
} else if wavelen > low_freq_wavelen {
|
||||||
})
|
freq / scale_factor
|
||||||
|
} else {
|
||||||
|
let smooth = (original_max_position_embeddings as f32 / wavelen
|
||||||
|
- low_freq_factor)
|
||||||
|
/ (high_freq_factor - low_freq_factor);
|
||||||
|
(1. - smooth) * freq / scale_factor + smooth * freq
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
};
|
||||||
|
|
||||||
|
let theta = Tensor::new(theta, dev)?;
|
||||||
|
let idx_theta = Tensor::arange(0, cfg.max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((cfg.max_seq_len, 1))?
|
||||||
|
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||||
|
// This is different from the paper, see:
|
||||||
|
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||||
|
let cos = idx_theta.cos()?.to_dtype(dtype)?;
|
||||||
|
let sin = idx_theta.sin()?.to_dtype(dtype)?;
|
||||||
|
Ok(Self { cos, sin })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn apply_rotary_emb_qkv(
|
fn apply_rotary_emb_qkv(
|
||||||
@ -109,8 +137,8 @@ impl RotaryEmbedding {
|
|||||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
|
let q_embed = candle_nn::rotary_emb::rope_i(q, &cos, &sin)?;
|
||||||
let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
|
let k_embed = candle_nn::rotary_emb::rope_i(k, &cos, &sin)?;
|
||||||
Ok((q_embed, k_embed))
|
Ok((q_embed, k_embed))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -347,7 +375,7 @@ impl LlamaModel {
|
|||||||
};
|
};
|
||||||
let mut xs = xs.clone();
|
let mut xs = xs.clone();
|
||||||
for layer in self.layers.iter_mut() {
|
for layer in self.layers.iter_mut() {
|
||||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?;
|
||||||
}
|
}
|
||||||
let ys = xs.narrow(1, seq_len - 1, 1)?.apply(&self.norm)?;
|
let ys = xs.narrow(1, seq_len - 1, 1)?.apply(&self.norm)?;
|
||||||
Ok(ys)
|
Ok(ys)
|
||||||
|
Reference in New Issue
Block a user