Support more mistral models. (#1927)

* Support more mistral models.

* Use the appropriate rope parameter.
Laurent Mazare
2024-03-24 08:04:04 +01:00
committed by GitHub
parent 5e70821dd0
commit e2b4829531
3 changed files with 70 additions and 26 deletions

candle-examples/examples/mistral/main.rs

@@ -122,6 +122,18 @@ impl TextGeneration {
     }
 }

+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "7b-v0.1")]
+    Mistral7bV01,
+    #[value(name = "7b-v0.2")]
+    Mistral7bV02,
+    #[value(name = "7b-instruct-v0.1")]
+    Mistral7bInstructV01,
+    #[value(name = "7b-instruct-v0.2")]
+    Mistral7bInstructV02,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -155,6 +167,10 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 10000)]
     sample_len: usize,

+    /// The model size to use.
+    #[arg(long, default_value = "7b-v0.1")]
+    which: Which,
+
     #[arg(long)]
     model_id: Option<String>,
@@ -164,6 +180,9 @@ struct Args {
     #[arg(long)]
     tokenizer_file: Option<String>,

+    #[arg(long)]
+    config_file: Option<String>,
+
     #[arg(long)]
     weight_files: Option<String>,
@@ -211,9 +230,17 @@ fn main() -> Result<()> {
         Some(model_id) => model_id,
         None => {
             if args.quantized {
+                if args.which != Which::Mistral7bV01 {
+                    anyhow::bail!("only 7b-v0.1 is available as a quantized model for now")
+                }
                 "lmz/candle-mistral".to_string()
             } else {
-                "mistralai/Mistral-7B-v0.1".to_string()
+                match args.which {
+                    Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
+                    Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
+                    Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
+                    Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
+                }
             }
         }
     };
@@ -243,7 +270,17 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

     let start = std::time::Instant::now();
-    let config = Config::config_7b_v0_1(args.use_flash_attn);
+    let config = match args.config_file {
+        Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
+        None => {
+            if args.quantized {
+                Config::config_7b_v0_1(args.use_flash_attn)
+            } else {
+                let config_file = repo.get("config.json")?;
+                serde_json::from_slice(&std::fs::read(config_file)?)?
+            }
+        }
+    };
     let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
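Note: putting this hunk together, config resolution now has three paths: an explicit --config-file, the hard-coded 7b-v0.1 parameters for the quantized model, and config.json fetched from the Hub otherwise. A minimal standalone sketch of that last path (hf_hub::api::sync::ApiRepo and the anyhow error type are assumptions, matching what the example already pulls in):

use anyhow::Result;
use candle_transformers::models::mistral::Config;

// Fetch the selected model's config.json from the Hub (or the local cache)
// and deserialize it, rather than hard-coding the 7b-v0.1 parameters.
fn load_config(repo: &hf_hub::api::sync::ApiRepo) -> Result<Config> {
    let config_file = repo.get("config.json")?; // downloads or reuses the cached file
    Ok(serde_json::from_slice(&std::fs::read(config_file)?)?)
}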

candle-transformers/src/models/mistral.rs

@@ -4,20 +4,25 @@ use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;

-#[derive(Debug, Clone, PartialEq)]
+fn default_use_flash_attn() -> bool {
+    false
+}
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
 pub struct Config {
-    pub(crate) vocab_size: usize,
-    pub(crate) hidden_size: usize,
-    pub(crate) intermediate_size: usize,
-    pub(crate) num_hidden_layers: usize,
-    pub(crate) num_attention_heads: usize,
-    pub(crate) num_key_value_heads: usize,
-    pub(crate) hidden_act: Activation,
-    pub(crate) max_position_embeddings: usize,
-    pub(crate) rms_norm_eps: f64,
-    pub(crate) rope_theta: f64,
-    pub(crate) sliding_window: usize,
-    pub(crate) use_flash_attn: bool,
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub num_key_value_heads: usize,
+    pub hidden_act: Activation,
+    pub max_position_embeddings: usize,
+    pub rms_norm_eps: f64,
+    pub rope_theta: f64,
+    pub sliding_window: Option<usize>,
+    #[serde(default = "default_use_flash_attn")]
+    pub use_flash_attn: bool,
 }

 impl Config {
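The serde::Deserialize derive is what makes the Hub config usable directly. A small sketch of the two serde behaviours the struct now leans on, using a hypothetical DemoConfig (assumes serde with the derive feature plus serde_json):

fn default_use_flash_attn() -> bool {
    false
}

#[derive(Debug, serde::Deserialize)]
struct DemoConfig {
    // Option fields fall back to None when the key is missing or null, so
    // configs that omit sliding_window still parse.
    sliding_window: Option<usize>,
    // Non-Option fields need an explicit default; serde calls this function
    // whenever the key is absent, and config.json never carries use_flash_attn.
    #[serde(default = "default_use_flash_attn")]
    use_flash_attn: bool,
}

fn main() -> serde_json::Result<()> {
    let cfg: DemoConfig = serde_json::from_str(r#"{ "sliding_window": null }"#)?;
    assert_eq!(cfg.sliding_window, None);
    assert!(!cfg.use_flash_attn);
    Ok(())
}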
@@ -34,7 +39,7 @@ impl Config {
             max_position_embeddings: 32768,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.,
-            sliding_window: 4096,
+            sliding_window: Some(4096),
             use_flash_attn,
         }
     }
@@ -53,7 +58,7 @@ impl Config {
             max_position_embeddings: 32768,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.,
-            sliding_window: 4096,
+            sliding_window: Some(4096),
             use_flash_attn,
         }
     }
@@ -71,7 +76,7 @@ impl Config {
             max_position_embeddings: 32768,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.,
-            sliding_window: 4096,
+            sliding_window: Some(4096),
             use_flash_attn,
         }
     }
@@ -92,11 +97,12 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+        let rope_theta = cfg.rope_theta as f32;
         let dim = cfg.hidden_size / cfg.num_attention_heads;
         let max_seq_len = cfg.max_position_embeddings;
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
-            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
+            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
             .collect();
         let inv_freq_len = inv_freq.len();
         let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
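Threading cfg.rope_theta through matters once configs come from the Hub: the v0.2-generation Mistral configs ship a much larger base (1_000_000 rather than 10_000, to the best of my knowledge), which slows the rotation of the upper-dimension pairs and keeps positions distinguishable over longer contexts. A self-contained sketch of the frequency table this hunk parameterizes:

// Pair i rotates at theta^(-i/dim) radians per position, so a larger
// rope_theta means slower rotation in the upper dimensions.
fn rope_inv_freq(dim: usize, rope_theta: f32) -> Vec<f32> {
    (0..dim)
        .step_by(2)
        .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
        .collect()
}

fn main() {
    // head_dim = 128 for Mistral-7B (hidden_size 4096 / 32 heads).
    let v01 = rope_inv_freq(128, 10_000.);
    let v02 = rope_inv_freq(128, 1_000_000.);
    assert!(v02.last() < v01.last()); // the slowest pair rotates far slower
}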
@@ -353,7 +359,7 @@ pub struct Model {
     layers: Vec<DecoderLayer>,
     norm: RmsNorm,
     lm_head: Linear,
-    sliding_window: usize,
+    sliding_window: Option<usize>,
     device: Device,
     dtype: DType,
 }
@@ -388,11 +394,11 @@ impl Model {
         tgt_len: usize,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
-        // Sliding window mask?
+        let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1);
         let mask: Vec<_> = (0..tgt_len)
             .flat_map(|i| {
                 (0..tgt_len).map(move |j| {
-                    if i < j || j + self.sliding_window < i {
+                    if i < j || j + sliding_window < i {
                         f32::NEG_INFINITY
                     } else {
                         0.
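The unwrap_or(tgt_len + 1) trick deserves a note: i - j never exceeds tgt_len - 1, so a window of tgt_len + 1 can never trigger and the mask reduces to plain causal attention when sliding_window is None (as in v0.2-style configs). A standalone sketch of the same mask logic:

// Row-major tgt_len x tgt_len additive mask: NEG_INFINITY above the
// diagonal (future keys) and, when a window is set, for keys more than
// `window` positions behind the query.
fn causal_mask(tgt_len: usize, sliding_window: Option<usize>) -> Vec<f32> {
    let window = sliding_window.unwrap_or(tgt_len + 1);
    (0..tgt_len)
        .flat_map(|i| {
            (0..tgt_len).map(move |j| {
                if i < j || j + window < i {
                    f32::NEG_INFINITY
                } else {
                    0.
                }
            })
        })
        .collect()
}

fn main() {
    let full = causal_mask(4, None);
    let windowed = causal_mask(4, Some(2));
    assert_eq!(full[12], 0.); // (i = 3, j = 0): visible without a window
    assert_eq!(windowed[12], f32::NEG_INFINITY); // masked: distance 3 > window 2
}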

candle-transformers/src/models/quantized_mistral.rs

@@ -21,11 +21,12 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 impl RotaryEmbedding {
     fn new(cfg: &Config, dev: &Device) -> Result<Self> {
+        let rope_theta = cfg.rope_theta as f32;
         let dim = cfg.hidden_size / cfg.num_attention_heads;
         let max_seq_len = cfg.max_position_embeddings;
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
-            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
+            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
             .collect();
         let inv_freq_len = inv_freq.len();
         let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
@@ -257,7 +258,7 @@ pub struct Model {
     layers: Vec<DecoderLayer>,
     norm: RmsNorm,
     lm_head: Linear,
-    sliding_window: usize,
+    sliding_window: Option<usize>,
     device: Device,
 }
@@ -290,11 +291,11 @@ impl Model {
         tgt_len: usize,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
-        // Sliding window mask?
+        let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1);
         let mask: Vec<_> = (0..tgt_len)
             .flat_map(|i| {
                 (0..tgt_len).map(move |j| {
-                    if i < j || j + self.sliding_window < i {
+                    if i < j || j + sliding_window < i {
                         f32::NEG_INFINITY
                     } else {
                         0.