From e2b4829531bb053c48e8124580695996b910ec00 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 24 Mar 2024 08:04:04 +0100
Subject: [PATCH] Support more mistral models. (#1927)

* Support more mistral models.

* Use the appropriate rope parameter.
---
 candle-examples/examples/mistral/main.rs   | 41 ++++++++++++++++-
 candle-transformers/src/models/mistral.rs  | 46 +++++++++++--------
 .../src/models/quantized_mistral.rs        |  9 ++--
 3 files changed, 70 insertions(+), 26 deletions(-)

diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 1cf4107c..a972279c 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -122,6 +122,18 @@ impl TextGeneration {
     }
 }
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "7b-v0.1")]
+    Mistral7bV01,
+    #[value(name = "7b-v0.2")]
+    Mistral7bV02,
+    #[value(name = "7b-instruct-v0.1")]
+    Mistral7bInstructV01,
+    #[value(name = "7b-instruct-v0.2")]
+    Mistral7bInstructV02,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -155,6 +167,10 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 10000)]
     sample_len: usize,
 
+    /// The model size to use.
+    #[arg(long, default_value = "7b-v0.1")]
+    which: Which,
+
     #[arg(long)]
     model_id: Option<String>,
 
@@ -164,6 +180,9 @@ struct Args {
     #[arg(long)]
     tokenizer_file: Option<String>,
 
+    #[arg(long)]
+    config_file: Option<String>,
+
     #[arg(long)]
     weight_files: Option<String>,
 
@@ -211,9 +230,17 @@ fn main() -> Result<()> {
         Some(model_id) => model_id,
         None => {
             if args.quantized {
+                if args.which != Which::Mistral7bV01 {
+                    anyhow::bail!("only 7b-v0.1 is available as a quantized model for now")
+                }
                 "lmz/candle-mistral".to_string()
             } else {
-                "mistralai/Mistral-7B-v0.1".to_string()
+                match args.which {
+                    Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
+                    Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
+                    Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
+                    Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
+                }
             }
         }
     };
@@ -243,7 +270,17 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
-    let config = Config::config_7b_v0_1(args.use_flash_attn);
+    let config = match args.config_file {
+        Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
+        None => {
+            if args.quantized {
+                Config::config_7b_v0_1(args.use_flash_attn)
+            } else {
+                let config_file = repo.get("config.json")?;
+                serde_json::from_slice(&std::fs::read(config_file)?)?
+            }
+        }
+    };
     let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index e40ae3ad..0e6200f5 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -4,20 +4,25 @@ use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
 
-#[derive(Debug, Clone, PartialEq)]
+fn default_use_flash_attn() -> bool {
+    false
+}
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
 pub struct Config {
-    pub(crate) vocab_size: usize,
-    pub(crate) hidden_size: usize,
-    pub(crate) intermediate_size: usize,
-    pub(crate) num_hidden_layers: usize,
-    pub(crate) num_attention_heads: usize,
-    pub(crate) num_key_value_heads: usize,
-    pub(crate) hidden_act: Activation,
-    pub(crate) max_position_embeddings: usize,
-    pub(crate) rms_norm_eps: f64,
-    pub(crate) rope_theta: f64,
-    pub(crate) sliding_window: usize,
-    pub(crate) use_flash_attn: bool,
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub num_key_value_heads: usize,
+    pub hidden_act: Activation,
+    pub max_position_embeddings: usize,
+    pub rms_norm_eps: f64,
+    pub rope_theta: f64,
+    pub sliding_window: Option<usize>,
+    #[serde(default = "default_use_flash_attn")]
+    pub use_flash_attn: bool,
 }
 
 impl Config {
@@ -34,7 +39,7 @@ impl Config {
             max_position_embeddings: 32768,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.,
-            sliding_window: 4096,
+            sliding_window: Some(4096),
             use_flash_attn,
         }
     }
@@ -53,7 +58,7 @@ impl Config {
             max_position_embeddings: 32768,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.,
-            sliding_window: 4096,
+            sliding_window: Some(4096),
            use_flash_attn,
         }
     }
@@ -71,7 +76,7 @@ impl Config {
             max_position_embeddings: 32768,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.,
-            sliding_window: 4096,
+            sliding_window: Some(4096),
             use_flash_attn,
         }
     }
@@ -92,11 +97,12 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+        let rope_theta = cfg.rope_theta as f32;
         let dim = cfg.hidden_size / cfg.num_attention_heads;
         let max_seq_len = cfg.max_position_embeddings;
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
-            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
+            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
             .collect();
         let inv_freq_len = inv_freq.len();
         let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
@@ -353,7 +359,7 @@ pub struct Model {
     layers: Vec<DecoderLayer>,
     norm: RmsNorm,
     lm_head: Linear,
-    sliding_window: usize,
+    sliding_window: Option<usize>,
     device: Device,
     dtype: DType,
 }
@@ -388,11 +394,11 @@ impl Model {
         tgt_len: usize,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
-        // Sliding window mask?
+        let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1);
         let mask: Vec<_> = (0..tgt_len)
             .flat_map(|i| {
                 (0..tgt_len).map(move |j| {
-                    if i < j || j + self.sliding_window < i {
+                    if i < j || j + sliding_window < i {
                         f32::NEG_INFINITY
                     } else {
                         0.
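Note on the masking change above: a minimal standalone sketch of the behaviour `prepare_decoder_attention_mask` now has with `sliding_window: Option<usize>`. The `causal_mask` helper and the `main` driver below are illustrative only and not part of the patch; the point is that `None` makes `unwrap_or(tgt_len + 1)` pick a window no offset can exceed, so the mask falls back to plain causal masking.

    // Sketch of the mask built in prepare_decoder_attention_mask (illustrative only).
    fn causal_mask(tgt_len: usize, sliding_window: Option<usize>) -> Vec<f32> {
        // With `None`, `j + window < i` can never hold, i.e. plain causal masking.
        let window = sliding_window.unwrap_or(tgt_len + 1);
        (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if i < j || j + window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect()
    }

    fn main() {
        // A window of 1 lets each position attend only to itself and its neighbour.
        let m = causal_mask(3, Some(1));
        assert_eq!(m[2 * 3], f32::NEG_INFINITY); // i = 2, j = 0 falls outside the window
        assert_eq!(m[2 * 3 + 1], 0.); // i = 2, j = 1 is inside
        // Without a window, only the upper triangle is masked.
        assert_eq!(causal_mask(3, None).iter().filter(|v| v.is_finite()).count(), 6);
    }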
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index 5f026f2b..2c5b7f74 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -21,11 +21,12 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 
 impl RotaryEmbedding {
     fn new(cfg: &Config, dev: &Device) -> Result<Self> {
+        let rope_theta = cfg.rope_theta as f32;
         let dim = cfg.hidden_size / cfg.num_attention_heads;
         let max_seq_len = cfg.max_position_embeddings;
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
-            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
+            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
             .collect();
         let inv_freq_len = inv_freq.len();
         let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
@@ -257,7 +258,7 @@ pub struct Model {
     layers: Vec<DecoderLayer>,
     norm: RmsNorm,
     lm_head: Linear,
-    sliding_window: usize,
+    sliding_window: Option<usize>,
     device: Device,
 }
 
@@ -290,11 +291,11 @@ impl Model {
         tgt_len: usize,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
-        // Sliding window mask?
+        let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1);
        let mask: Vec<_> = (0..tgt_len)
             .flat_map(|i| {
                 (0..tgt_len).map(move |j| {
-                    if i < j || j + self.sliding_window < i {
+                    if i < j || j + sliding_window < i {
                         f32::NEG_INFINITY
                     } else {
                         0.
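As background for the rope change (the second bullet in the commit message), here is a standalone sketch of the inverse-frequency table that both `RotaryEmbedding::new` implementations now derive from `cfg.rope_theta` instead of a hard-coded 10_000. The `rope_inv_freq` helper is hypothetical, and the 1e6 base for the v0.2 checkpoints is quoted from their published config rather than from this patch.

    // Illustrative only: the per-dimension frequencies RotaryEmbedding::new builds.
    fn rope_inv_freq(head_dim: usize, rope_theta: f32) -> Vec<f32> {
        (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / rope_theta.powf(i as f32 / head_dim as f32))
            .collect()
    }

    fn main() {
        // head_dim = hidden_size / num_attention_heads = 4096 / 32 = 128 for the 7B models.
        let v01 = rope_inv_freq(128, 10_000.); // base used by 7b-v0.1
        let v02 = rope_inv_freq(128, 1_000_000.); // base reported for the v0.2 configs
        assert_eq!(v01.len(), 64);
        assert_eq!(v01[0], 1.0); // the first pair always rotates at unit frequency
        assert!(v02[63] < v01[63]); // a larger theta slows down the high dimensions
    }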