Support for mistral-nemo. (#2396)

Laurent Mazare
2024-08-04 18:52:40 +01:00
committed by GitHub
parent 89eae41efd
commit 2be9bd211e
2 changed files with 26 additions and 12 deletions


@@ -149,6 +149,10 @@ enum Which {
     Mistral7bInstructV02,
     #[value(name = "7b-maths-v0.1")]
     Mathstral7bV01,
+    #[value(name = "nemo-2407")]
+    MistralNemo2407,
+    #[value(name = "nemo-instruct-2407")]
+    MistralNemoInstruct2407,
 }
 
 #[derive(Parser, Debug)]
@@ -263,13 +267,16 @@ fn main() -> Result<()> {
                 }
                 "lmz/candle-mistral".to_string()
             } else {
-                match args.which {
-                    Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
-                    Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
-                    Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
-                    Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
-                    Which::Mathstral7bV01 => "mistralai/mathstral-7B-v0.1".to_string(),
-                }
+                let name = match args.which {
+                    Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1",
+                    Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2",
+                    Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1",
+                    Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2",
+                    Which::Mathstral7bV01 => "mistralai/mathstral-7B-v0.1",
+                    Which::MistralNemo2407 => "mistralai/Mistral-Nemo-Base-2407",
+                    Which::MistralNemoInstruct2407 => "mistralai/Mistral-Nemo-Instruct-2407",
+                };
+                name.to_string()
             }
         }
     };
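For readers skimming the diff, here is the repo-id mapping above pulled out into a standalone sketch. `Which` below is a local stand-in for the example's clap `ValueEnum` of the same name, and `default_model_id` is a hypothetical helper; the surrounding CLI wiring is assumed rather than shown in this commit.

// Sketch only: mirrors the mapping introduced in the diff above.
#[allow(dead_code)]
#[derive(Clone, Copy)]
enum Which {
    Mistral7bV01,
    Mistral7bV02,
    Mistral7bInstructV01,
    Mistral7bInstructV02,
    Mathstral7bV01,
    MistralNemo2407,
    MistralNemoInstruct2407,
}

// Resolve the default Hugging Face model id for a given variant.
fn default_model_id(which: Which) -> String {
    let name = match which {
        Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1",
        Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2",
        Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1",
        Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2",
        Which::Mathstral7bV01 => "mistralai/mathstral-7B-v0.1",
        Which::MistralNemo2407 => "mistralai/Mistral-Nemo-Base-2407",
        Which::MistralNemoInstruct2407 => "mistralai/Mistral-Nemo-Instruct-2407",
    };
    name.to_string()
}

fn main() {
    assert_eq!(
        default_model_id(Which::MistralNemoInstruct2407),
        "mistralai/Mistral-Nemo-Instruct-2407"
    );
}

Returning `&'static str` from the match and calling `.to_string()` once keeps the arms terse and avoids repeating the allocation in every branch.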


@@ -15,6 +15,7 @@ pub struct Config {
     pub intermediate_size: usize,
     pub num_hidden_layers: usize,
     pub num_attention_heads: usize,
+    pub head_dim: Option<usize>,
     pub num_key_value_heads: usize,
     pub hidden_act: Activation,
     pub max_position_embeddings: usize,
@@ -34,6 +35,7 @@ impl Config {
             intermediate_size: 14336,
             num_hidden_layers: 32,
             num_attention_heads: 32,
+            head_dim: None,
             num_key_value_heads: 8,
             hidden_act: Activation::Silu,
             max_position_embeddings: 32768,
@@ -53,6 +55,7 @@ impl Config {
             intermediate_size: 14336,
             num_hidden_layers: 32,
             num_attention_heads: 32,
+            head_dim: None,
             num_key_value_heads: 8,
             hidden_act: Activation::Silu,
             max_position_embeddings: 32768,
@@ -71,6 +74,7 @@ impl Config {
             intermediate_size: 14336,
             num_hidden_layers: 32,
             num_attention_heads: 32,
+            head_dim: None,
             num_key_value_heads: 8,
             hidden_act: Activation::Silu,
             max_position_embeddings: 32768,
@@ -80,6 +84,11 @@ impl Config {
             use_flash_attn,
         }
     }
+
+    fn head_dim(&self) -> usize {
+        self.head_dim
+            .unwrap_or(self.hidden_size / self.num_attention_heads)
+    }
 }
 
 #[derive(Debug, Clone)]
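The new `head_dim` field is optional so that existing configs keep the derived value while Mistral Nemo can override it. Below is a minimal sketch of the fallback, using numbers taken from the published model cards rather than from this diff (Mistral 7B: hidden_size 4096, 32 heads; Nemo: hidden_size 5120, 32 heads, head_dim 128).

// Sketch only: a cut-down Config illustrating the Option<usize> fallback added above.
struct Cfg {
    hidden_size: usize,
    num_attention_heads: usize,
    head_dim: Option<usize>,
}

impl Cfg {
    // Same logic as the head_dim() helper in the diff.
    fn head_dim(&self) -> usize {
        self.head_dim
            .unwrap_or(self.hidden_size / self.num_attention_heads)
    }
}

fn main() {
    // Mistral 7B style config: no explicit head_dim, so 4096 / 32 = 128 is derived.
    let mistral_7b = Cfg { hidden_size: 4096, num_attention_heads: 32, head_dim: None };
    assert_eq!(mistral_7b.head_dim(), 128);

    // Nemo-like config (values assumed from the model card): deriving from
    // hidden_size would give 5120 / 32 = 160, but the checkpoint uses 128,
    // hence the explicit override.
    let nemo = Cfg { hidden_size: 5120, num_attention_heads: 32, head_dim: Some(128) };
    assert_eq!(nemo.head_dim(), 128);
}

Without the override, the derived value for Nemo would be 160 and the projection shapes would no longer match the checkpoint.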
@@ -91,7 +100,7 @@ struct RotaryEmbedding {
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let rope_theta = cfg.rope_theta as f32;
-        let dim = cfg.hidden_size / cfg.num_attention_heads;
+        let dim = cfg.head_dim();
         let max_seq_len = cfg.max_position_embeddings;
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
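Since the rotary embedding dimension now comes from `cfg.head_dim()`, the size of the frequency table follows the explicit head dimension as well. A small sketch of the inverse-frequency construction, with `dim` and `rope_theta` values assumed for illustration (they are not part of this diff):

fn main() {
    // Assumed illustration values: a Nemo-like head_dim and a large rope_theta.
    let dim = 128usize;
    let rope_theta = 1_000_000f32;

    // Mirrors the (0..dim).step_by(2) construction in the diff: one inverse
    // frequency per pair of rotary dimensions, i.e. dim / 2 entries.
    let inv_freq: Vec<f32> = (0..dim)
        .step_by(2)
        .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
        .collect();
    assert_eq!(inv_freq.len(), dim / 2);
}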
@@ -183,7 +192,6 @@ struct Attention {
     num_kv_heads: usize,
     num_kv_groups: usize,
     head_dim: usize,
-    hidden_size: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     kv_cache: Option<(Tensor, Tensor)>,
     use_flash_attn: bool,
@@ -195,7 +203,7 @@ impl Attention {
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
         let num_kv_groups = num_heads / num_kv_heads;
-        let head_dim = hidden_sz / num_heads;
+        let head_dim = cfg.head_dim();
         let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
         let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
         let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
@@ -209,7 +217,6 @@ impl Attention {
             num_kv_heads,
             num_kv_groups,
             head_dim,
-            hidden_size: hidden_sz,
             rotary_emb,
             kv_cache: None,
             use_flash_attn: cfg.use_flash_attn,
@@ -277,7 +284,7 @@ impl Attention {
         };
         attn_output
             .transpose(1, 2)?
-            .reshape((b_sz, q_len, self.hidden_size))?
+            .reshape((b_sz, q_len, self.num_heads * self.head_dim))?
             .apply(&self.o_proj)
     }
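The final reshape uses `num_heads * head_dim` instead of `hidden_size` because the two values can differ once `head_dim` is overridden; the output projection is then presumably what maps the concatenated width back to `hidden_size`. A tiny arithmetic sketch with the assumed Nemo-like numbers:

fn main() {
    // Assumed Nemo-like values (from the published config, not this diff).
    let hidden_size = 5120usize;
    let num_heads = 32usize;
    let head_dim = 128usize;

    // The attention output concatenates per-head outputs, so its last
    // dimension is num_heads * head_dim, which is no longer hidden_size.
    let concat_width = num_heads * head_dim;
    assert_eq!(concat_width, 4096);
    assert_ne!(concat_width, hidden_size);
}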