Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 18:28:24 +00:00.
Support more mistral models. (#1927)
* Support more mistral models.
* Use the appropriate rope parameter.
candle-transformers/src/models/mistral.rs

@@ -4,20 +4,25 @@ use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
 
-#[derive(Debug, Clone, PartialEq)]
+fn default_use_flash_attn() -> bool {
+    false
+}
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
 pub struct Config {
-    pub(crate) vocab_size: usize,
-    pub(crate) hidden_size: usize,
-    pub(crate) intermediate_size: usize,
-    pub(crate) num_hidden_layers: usize,
-    pub(crate) num_attention_heads: usize,
-    pub(crate) num_key_value_heads: usize,
-    pub(crate) hidden_act: Activation,
-    pub(crate) max_position_embeddings: usize,
-    pub(crate) rms_norm_eps: f64,
-    pub(crate) rope_theta: f64,
-    pub(crate) sliding_window: usize,
-    pub(crate) use_flash_attn: bool,
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub num_key_value_heads: usize,
+    pub hidden_act: Activation,
+    pub max_position_embeddings: usize,
+    pub rms_norm_eps: f64,
+    pub rope_theta: f64,
+    pub sliding_window: Option<usize>,
+    #[serde(default = "default_use_flash_attn")]
+    pub use_flash_attn: bool,
 }
 
 impl Config {
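The `serde::Deserialize` derive plus the `default_use_flash_attn` hook let `Config` be parsed straight from a checkpoint's `config.json`, where `use_flash_attn` never appears and `sliding_window` may be `null`. A minimal standalone sketch of the same serde pattern (the struct below is an illustrative subset, not candle's full type):

    use serde::Deserialize;

    fn default_use_flash_attn() -> bool {
        false
    }

    // Illustrative subset of the real Config; field names mirror the diff above.
    #[derive(Debug, Deserialize)]
    struct Config {
        hidden_size: usize,
        // Absent from upstream config.json files, so it needs a default.
        #[serde(default = "default_use_flash_attn")]
        use_flash_attn: bool,
        // serde maps a JSON `null` (or a missing key) to None.
        sliding_window: Option<usize>,
    }

    fn main() -> serde_json::Result<()> {
        // No `use_flash_attn` key and an explicit null window, as in
        // e.g. Mistral-7B-v0.2 style configs.
        let json = r#"{ "hidden_size": 4096, "sliding_window": null }"#;
        let cfg: Config = serde_json::from_str(json)?;
        assert!(!cfg.use_flash_attn);
        assert_eq!(cfg.sliding_window, None);
        Ok(())
    }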
@@ -34,7 +39,7 @@ impl Config {
             max_position_embeddings: 32768,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.,
-            sliding_window: 4096,
+            sliding_window: Some(4096),
             use_flash_attn,
         }
     }
@@ -53,7 +58,7 @@ impl Config {
             max_position_embeddings: 32768,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.,
-            sliding_window: 4096,
+            sliding_window: Some(4096),
             use_flash_attn,
         }
     }
@@ -71,7 +76,7 @@ impl Config {
             max_position_embeddings: 32768,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.,
-            sliding_window: 4096,
+            sliding_window: Some(4096),
             use_flash_attn,
         }
     }
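Widening the fields from `pub(crate)` to `pub` works toward the same goal: downstream code can now assemble or patch a `Config` for a checkpoint that has no built-in preset. A hedged sketch, assuming `candle-transformers` as a dependency and a `config_7b_v0_1(use_flash_attn)` constructor (the constructor bodies appear in the hunks above, but their names are not shown in this excerpt):

    use candle_transformers::models::mistral::Config;

    fn main() {
        // Start from the built-in 7B v0.1 preset (no flash attention)...
        let base = Config::config_7b_v0_1(false);
        // ...and patch it for a hypothetical v0.2-style checkpoint:
        // no sliding window, larger RoPE base. Struct update syntax
        // only works from outside the crate because the fields are pub.
        let cfg = Config {
            sliding_window: None,
            rope_theta: 1_000_000.,
            ..base
        };
        assert_eq!(cfg.sliding_window, None);
    }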
@@ -92,11 +97,12 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+        let rope_theta = cfg.rope_theta as f32;
         let dim = cfg.hidden_size / cfg.num_attention_heads;
         let max_seq_len = cfg.max_position_embeddings;
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
-            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
+            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
             .collect();
         let inv_freq_len = inv_freq.len();
         let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
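This is the "appropriate rope parameter" from the commit message: the RoPE base is no longer hard-coded to 10,000 but read from `cfg.rope_theta`, which newer Mistral checkpoints raise (v0.2-style configs use 1e6) so positional phases cycle more slowly over long contexts. A dependency-free sketch of the frequency computation:

    // Inverse RoPE frequencies for even dimensions, as in RotaryEmbedding::new.
    fn inv_freq(dim: usize, rope_theta: f32) -> Vec<f32> {
        (0..dim)
            .step_by(2)
            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
            .collect()
    }

    fn main() {
        let dim = 128; // hidden_size / num_attention_heads for Mistral-7B
        let base = inv_freq(dim, 10_000.0);
        let long = inv_freq(dim, 1_000_000.0);
        // With a larger theta the lowest frequency drops, so the rotary
        // embedding can distinguish positions over a longer range.
        println!("theta=1e4 last inv_freq: {}", base[base.len() - 1]);
        println!("theta=1e6 last inv_freq: {}", long[long.len() - 1]);
    }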
@@ -353,7 +359,7 @@ pub struct Model {
     layers: Vec<DecoderLayer>,
     norm: RmsNorm,
     lm_head: Linear,
-    sliding_window: usize,
+    sliding_window: Option<usize>,
     device: Device,
     dtype: DType,
 }
@@ -388,11 +394,11 @@ impl Model {
         tgt_len: usize,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
-        // Sliding window mask?
+        let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1);
         let mask: Vec<_> = (0..tgt_len)
             .flat_map(|i| {
                 (0..tgt_len).map(move |j| {
-                    if i < j || j + self.sliding_window < i {
+                    if i < j || j + sliding_window < i {
                         f32::NEG_INFINITY
                     } else {
                         0.
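Note the `unwrap_or(tgt_len + 1)` fallback: since `i` and `j` are both below `tgt_len`, the condition `j + (tgt_len + 1) < i` can never hold, so a `None` window degrades cleanly to a plain causal mask. A small self-contained sketch of the mask construction:

    // Causal mask with an optional sliding window, mirroring
    // prepare_decoder_attention_mask in the hunk above.
    fn mask(tgt_len: usize, sliding_window: Option<usize>) -> Vec<Vec<f32>> {
        // With None, the window is too large to ever trigger: pure causal mask.
        let sliding_window = sliding_window.unwrap_or(tgt_len + 1);
        (0..tgt_len)
            .map(|i| {
                (0..tgt_len)
                    .map(|j| {
                        // Block future tokens (i < j) and tokens that fell
                        // out of the window (j + sliding_window < i).
                        if i < j || j + sliding_window < i {
                            f32::NEG_INFINITY
                        } else {
                            0.
                        }
                    })
                    .collect()
            })
            .collect()
    }

    fn main() {
        // Row i = query position, column j = key position.
        for row in mask(5, Some(2)) {
            println!("{row:?}");
        }
        // Row 4 masks columns 0 and 1: only the trailing window of keys is visible.
    }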
candle-transformers/src/models/quantized_mistral.rs

@@ -21,11 +21,12 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 
 impl RotaryEmbedding {
     fn new(cfg: &Config, dev: &Device) -> Result<Self> {
+        let rope_theta = cfg.rope_theta as f32;
         let dim = cfg.hidden_size / cfg.num_attention_heads;
         let max_seq_len = cfg.max_position_embeddings;
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
-            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
+            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
             .collect();
         let inv_freq_len = inv_freq.len();
         let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
@@ -257,7 +258,7 @@ pub struct Model {
     layers: Vec<DecoderLayer>,
     norm: RmsNorm,
     lm_head: Linear,
-    sliding_window: usize,
+    sliding_window: Option<usize>,
     device: Device,
 }
 
@@ -290,11 +291,11 @@ impl Model {
         tgt_len: usize,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
-        // Sliding window mask?
+        let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1);
         let mask: Vec<_> = (0..tgt_len)
             .flat_map(|i| {
                 (0..tgt_len).map(move |j| {
-                    if i < j || j + self.sliding_window < i {
+                    if i < j || j + sliding_window < i {
                         f32::NEG_INFINITY
                     } else {
                         0.