mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Support more mistral models. (#1927)
* Support more mistral models. * Use the appropriate rope parameter.
This commit is contained in:
@ -122,6 +122,18 @@ impl TextGeneration {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "7b-v0.1")]
|
||||||
|
Mistral7bV01,
|
||||||
|
#[value(name = "7b-v0.2")]
|
||||||
|
Mistral7bV02,
|
||||||
|
#[value(name = "7b-instruct-v0.1")]
|
||||||
|
Mistral7bInstructV01,
|
||||||
|
#[value(name = "7b-instruct-v0.2")]
|
||||||
|
Mistral7bInstructV02,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -155,6 +167,10 @@ struct Args {
|
|||||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "7b-v0.1")]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
|
|
||||||
@ -164,6 +180,9 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tokenizer_file: Option<String>,
|
tokenizer_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
config_file: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
weight_files: Option<String>,
|
weight_files: Option<String>,
|
||||||
|
|
||||||
@ -211,9 +230,17 @@ fn main() -> Result<()> {
|
|||||||
Some(model_id) => model_id,
|
Some(model_id) => model_id,
|
||||||
None => {
|
None => {
|
||||||
if args.quantized {
|
if args.quantized {
|
||||||
|
if args.which != Which::Mistral7bV01 {
|
||||||
|
anyhow::bail!("only 7b-v0.1 is available as a quantized model for now")
|
||||||
|
}
|
||||||
"lmz/candle-mistral".to_string()
|
"lmz/candle-mistral".to_string()
|
||||||
} else {
|
} else {
|
||||||
"mistralai/Mistral-7B-v0.1".to_string()
|
match args.which {
|
||||||
|
Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
|
||||||
|
Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
|
||||||
|
Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
|
||||||
|
Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -243,7 +270,17 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let config = Config::config_7b_v0_1(args.use_flash_attn);
|
let config = match args.config_file {
|
||||||
|
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
|
||||||
|
None => {
|
||||||
|
if args.quantized {
|
||||||
|
Config::config_7b_v0_1(args.use_flash_attn)
|
||||||
|
} else {
|
||||||
|
let config_file = repo.get("config.json")?;
|
||||||
|
serde_json::from_slice(&std::fs::read(config_file)?)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let (model, device) = if args.quantized {
|
let (model, device) = if args.quantized {
|
||||||
let filename = &filenames[0];
|
let filename = &filenames[0];
|
||||||
|
@ -4,20 +4,25 @@ use candle::{DType, Device, Module, Result, Tensor, D};
|
|||||||
use candle_nn::{Activation, VarBuilder};
|
use candle_nn::{Activation, VarBuilder};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
fn default_use_flash_attn() -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub(crate) vocab_size: usize,
|
pub vocab_size: usize,
|
||||||
pub(crate) hidden_size: usize,
|
pub hidden_size: usize,
|
||||||
pub(crate) intermediate_size: usize,
|
pub intermediate_size: usize,
|
||||||
pub(crate) num_hidden_layers: usize,
|
pub num_hidden_layers: usize,
|
||||||
pub(crate) num_attention_heads: usize,
|
pub num_attention_heads: usize,
|
||||||
pub(crate) num_key_value_heads: usize,
|
pub num_key_value_heads: usize,
|
||||||
pub(crate) hidden_act: Activation,
|
pub hidden_act: Activation,
|
||||||
pub(crate) max_position_embeddings: usize,
|
pub max_position_embeddings: usize,
|
||||||
pub(crate) rms_norm_eps: f64,
|
pub rms_norm_eps: f64,
|
||||||
pub(crate) rope_theta: f64,
|
pub rope_theta: f64,
|
||||||
pub(crate) sliding_window: usize,
|
pub sliding_window: Option<usize>,
|
||||||
pub(crate) use_flash_attn: bool,
|
#[serde(default = "default_use_flash_attn")]
|
||||||
|
pub use_flash_attn: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
@ -34,7 +39,7 @@ impl Config {
|
|||||||
max_position_embeddings: 32768,
|
max_position_embeddings: 32768,
|
||||||
rms_norm_eps: 1e-5,
|
rms_norm_eps: 1e-5,
|
||||||
rope_theta: 10_000.,
|
rope_theta: 10_000.,
|
||||||
sliding_window: 4096,
|
sliding_window: Some(4096),
|
||||||
use_flash_attn,
|
use_flash_attn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -53,7 +58,7 @@ impl Config {
|
|||||||
max_position_embeddings: 32768,
|
max_position_embeddings: 32768,
|
||||||
rms_norm_eps: 1e-5,
|
rms_norm_eps: 1e-5,
|
||||||
rope_theta: 10_000.,
|
rope_theta: 10_000.,
|
||||||
sliding_window: 4096,
|
sliding_window: Some(4096),
|
||||||
use_flash_attn,
|
use_flash_attn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -71,7 +76,7 @@ impl Config {
|
|||||||
max_position_embeddings: 32768,
|
max_position_embeddings: 32768,
|
||||||
rms_norm_eps: 1e-5,
|
rms_norm_eps: 1e-5,
|
||||||
rope_theta: 10_000.,
|
rope_theta: 10_000.,
|
||||||
sliding_window: 4096,
|
sliding_window: Some(4096),
|
||||||
use_flash_attn,
|
use_flash_attn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -92,11 +97,12 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
|||||||
|
|
||||||
impl RotaryEmbedding {
|
impl RotaryEmbedding {
|
||||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
|
let rope_theta = cfg.rope_theta as f32;
|
||||||
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||||
let max_seq_len = cfg.max_position_embeddings;
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
let inv_freq: Vec<_> = (0..dim)
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
.step_by(2)
|
.step_by(2)
|
||||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
|
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
|
||||||
.collect();
|
.collect();
|
||||||
let inv_freq_len = inv_freq.len();
|
let inv_freq_len = inv_freq.len();
|
||||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||||
@ -353,7 +359,7 @@ pub struct Model {
|
|||||||
layers: Vec<DecoderLayer>,
|
layers: Vec<DecoderLayer>,
|
||||||
norm: RmsNorm,
|
norm: RmsNorm,
|
||||||
lm_head: Linear,
|
lm_head: Linear,
|
||||||
sliding_window: usize,
|
sliding_window: Option<usize>,
|
||||||
device: Device,
|
device: Device,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
}
|
}
|
||||||
@ -388,11 +394,11 @@ impl Model {
|
|||||||
tgt_len: usize,
|
tgt_len: usize,
|
||||||
seqlen_offset: usize,
|
seqlen_offset: usize,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
// Sliding window mask?
|
let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1);
|
||||||
let mask: Vec<_> = (0..tgt_len)
|
let mask: Vec<_> = (0..tgt_len)
|
||||||
.flat_map(|i| {
|
.flat_map(|i| {
|
||||||
(0..tgt_len).map(move |j| {
|
(0..tgt_len).map(move |j| {
|
||||||
if i < j || j + self.sliding_window < i {
|
if i < j || j + sliding_window < i {
|
||||||
f32::NEG_INFINITY
|
f32::NEG_INFINITY
|
||||||
} else {
|
} else {
|
||||||
0.
|
0.
|
||||||
|
@ -21,11 +21,12 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
|||||||
|
|
||||||
impl RotaryEmbedding {
|
impl RotaryEmbedding {
|
||||||
fn new(cfg: &Config, dev: &Device) -> Result<Self> {
|
fn new(cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
|
let rope_theta = cfg.rope_theta as f32;
|
||||||
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||||
let max_seq_len = cfg.max_position_embeddings;
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
let inv_freq: Vec<_> = (0..dim)
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
.step_by(2)
|
.step_by(2)
|
||||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
|
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
|
||||||
.collect();
|
.collect();
|
||||||
let inv_freq_len = inv_freq.len();
|
let inv_freq_len = inv_freq.len();
|
||||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
|
||||||
@ -257,7 +258,7 @@ pub struct Model {
|
|||||||
layers: Vec<DecoderLayer>,
|
layers: Vec<DecoderLayer>,
|
||||||
norm: RmsNorm,
|
norm: RmsNorm,
|
||||||
lm_head: Linear,
|
lm_head: Linear,
|
||||||
sliding_window: usize,
|
sliding_window: Option<usize>,
|
||||||
device: Device,
|
device: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -290,11 +291,11 @@ impl Model {
|
|||||||
tgt_len: usize,
|
tgt_len: usize,
|
||||||
seqlen_offset: usize,
|
seqlen_offset: usize,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
// Sliding window mask?
|
let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1);
|
||||||
let mask: Vec<_> = (0..tgt_len)
|
let mask: Vec<_> = (0..tgt_len)
|
||||||
.flat_map(|i| {
|
.flat_map(|i| {
|
||||||
(0..tgt_len).map(move |j| {
|
(0..tgt_len).map(move |j| {
|
||||||
if i < j || j + self.sliding_window < i {
|
if i < j || j + sliding_window < i {
|
||||||
f32::NEG_INFINITY
|
f32::NEG_INFINITY
|
||||||
} else {
|
} else {
|
||||||
0.
|
0.
|
||||||
|
Reference in New Issue
Block a user