Support more mistral models. (#1927)

* Support more mistral models.

* Use the appropriate rope parameter.
This commit is contained in:
Laurent Mazare
2024-03-24 08:04:04 +01:00
committed by GitHub
parent 5e70821dd0
commit e2b4829531
3 changed files with 70 additions and 26 deletions

View File

@ -21,11 +21,12 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
impl RotaryEmbedding {
fn new(cfg: &Config, dev: &Device) -> Result<Self> {
let rope_theta = cfg.rope_theta as f32;
let dim = cfg.hidden_size / cfg.num_attention_heads;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
@ -257,7 +258,7 @@ pub struct Model {
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
sliding_window: usize,
sliding_window: Option<usize>,
device: Device,
}
@ -290,11 +291,11 @@ impl Model {
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
// Sliding window mask?
let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1);
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + self.sliding_window < i {
if i < j || j + sliding_window < i {
f32::NEG_INFINITY
} else {
0.