Support for mistral-nemo. (#2396)
(mirror of https://github.com/huggingface/candle.git)
@@ -149,6 +149,10 @@ enum Which {
     Mistral7bInstructV02,
     #[value(name = "7b-maths-v0.1")]
     Mathstral7bV01,
+    #[value(name = "nemo-2407")]
+    MistralNemo2407,
+    #[value(name = "nemo-instruct-2407")]
+    MistralNemoInstruct2407,
 }
 
 #[derive(Parser, Debug)]
@@ -263,13 +267,16 @@ fn main() -> Result<()> {
                 }
                 "lmz/candle-mistral".to_string()
             } else {
-                match args.which {
-                    Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
-                    Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
-                    Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
-                    Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
-                    Which::Mathstral7bV01 => "mistralai/mathstral-7B-v0.1".to_string(),
-                }
+                let name = match args.which {
+                    Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1",
+                    Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2",
+                    Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1",
+                    Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2",
+                    Which::Mathstral7bV01 => "mistralai/mathstral-7B-v0.1",
+                    Which::MistralNemo2407 => "mistralai/Mistral-Nemo-Base-2407",
+                    Which::MistralNemoInstruct2407 => "mistralai/Mistral-Nemo-Instruct-2407",
+                };
+                name.to_string()
             }
         }
     };
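
Note on usage: with the two new Which values wired into the model-id lookup above, the Nemo checkpoints become selectable from the command line. Assuming the example binary and its clap-derived --which flag are otherwise unchanged, an invocation would look roughly like:

    cargo run --example mistral --release -- --which nemo-instruct-2407 --prompt "Tell me a story."

Here nemo-2407 resolves to mistralai/Mistral-Nemo-Base-2407 and nemo-instruct-2407 to mistralai/Mistral-Nemo-Instruct-2407, as in the match arms above.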
@@ -15,6 +15,7 @@ pub struct Config {
     pub intermediate_size: usize,
     pub num_hidden_layers: usize,
     pub num_attention_heads: usize,
+    pub head_dim: Option<usize>,
     pub num_key_value_heads: usize,
     pub hidden_act: Activation,
     pub max_position_embeddings: usize,

@@ -34,6 +35,7 @@ impl Config {
             intermediate_size: 14336,
             num_hidden_layers: 32,
             num_attention_heads: 32,
+            head_dim: None,
             num_key_value_heads: 8,
             hidden_act: Activation::Silu,
             max_position_embeddings: 32768,

@@ -53,6 +55,7 @@ impl Config {
             intermediate_size: 14336,
             num_hidden_layers: 32,
             num_attention_heads: 32,
+            head_dim: None,
             num_key_value_heads: 8,
             hidden_act: Activation::Silu,
             max_position_embeddings: 32768,

@@ -71,6 +74,7 @@ impl Config {
             intermediate_size: 14336,
             num_hidden_layers: 32,
             num_attention_heads: 32,
+            head_dim: None,
             num_key_value_heads: 8,
             hidden_act: Activation::Silu,
             max_position_embeddings: 32768,

@@ -80,6 +84,11 @@ impl Config {
             use_flash_attn,
         }
     }
+
+    fn head_dim(&self) -> usize {
+        self.head_dim
+            .unwrap_or(self.hidden_size / self.num_attention_heads)
+    }
 }
 
 #[derive(Debug, Clone)]
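
Note on the config change: the optional head_dim is the substantive part of this commit. Mistral-Nemo's published config reportedly sets head_dim to 128 even though hidden_size / num_attention_heads would give 5120 / 32 = 160, so the head dimension can no longer be derived from the hidden size. Because the new field is an Option, configs that omit head_dim keep the old derived value, which is why the existing presets above only need head_dim: None. A minimal standalone sketch of the fallback (the Nemo numbers are assumptions taken from the upstream config and worth double-checking):

    // Standalone sketch of the head_dim fallback; field names mirror the diff.
    struct Cfg {
        hidden_size: usize,
        num_attention_heads: usize,
        head_dim: Option<usize>,
    }

    impl Cfg {
        fn head_dim(&self) -> usize {
            // Prefer the explicit value, otherwise derive it as before.
            self.head_dim
                .unwrap_or(self.hidden_size / self.num_attention_heads)
        }
    }

    fn main() {
        // Mistral-7B style: no explicit head_dim, fall back to 4096 / 32 = 128.
        let mistral_7b = Cfg { hidden_size: 4096, num_attention_heads: 32, head_dim: None };
        assert_eq!(mistral_7b.head_dim(), 128);

        // Mistral-Nemo style (assumed numbers): the explicit 128 wins over 5120 / 32 = 160.
        let nemo = Cfg { hidden_size: 5120, num_attention_heads: 32, head_dim: Some(128) };
        assert_eq!(nemo.head_dim(), 128);
    }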
@@ -91,7 +100,7 @@ struct RotaryEmbedding {
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let rope_theta = cfg.rope_theta as f32;
-        let dim = cfg.hidden_size / cfg.num_attention_heads;
+        let dim = cfg.head_dim();
         let max_seq_len = cfg.max_position_embeddings;
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)

@@ -183,7 +192,6 @@ struct Attention {
     num_kv_heads: usize,
     num_kv_groups: usize,
     head_dim: usize,
-    hidden_size: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     kv_cache: Option<(Tensor, Tensor)>,
     use_flash_attn: bool,

@@ -195,7 +203,7 @@ impl Attention {
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
         let num_kv_groups = num_heads / num_kv_heads;
-        let head_dim = hidden_sz / num_heads;
+        let head_dim = cfg.head_dim();
         let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
         let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
         let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;

@@ -209,7 +217,6 @@ impl Attention {
             num_kv_heads,
             num_kv_groups,
             head_dim,
-            hidden_size: hidden_sz,
             rotary_emb,
             kv_cache: None,
             use_flash_attn: cfg.use_flash_attn,

@@ -277,7 +284,7 @@ impl Attention {
         };
         attn_output
             .transpose(1, 2)?
-            .reshape((b_sz, q_len, self.hidden_size))?
+            .reshape((b_sz, q_len, self.num_heads * self.head_dim))?
             .apply(&self.o_proj)
     }
 
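
Note on the attention output: the last hunk follows from the same decoupling. Once head_dim may differ from hidden_size / num_heads, the attention output flattens to num_heads * head_dim (4096 under the assumed Nemo shapes) rather than hidden_size (5120), and o_proj then maps that back to the hidden size. A small sanity-check sketch, with the Nemo numbers again assumed:

    fn main() {
        // Assumed Mistral-Nemo shapes: hidden_size 5120, 32 heads of dimension 128.
        let (hidden_size, num_heads, head_dim) = (5120usize, 32usize, 128usize);
        let (b_sz, q_len) = (1usize, 7usize);
        // The attention output holds b_sz * q_len * num_heads * head_dim elements,
        let elems = b_sz * q_len * num_heads * head_dim;
        // so reshaping to (b_sz, q_len, num_heads * head_dim) always matches the element count,
        assert_eq!(elems, b_sz * q_len * (num_heads * head_dim));
        // while (b_sz, q_len, hidden_size) does not whenever head_dim != hidden_size / num_heads.
        assert_ne!(elems, b_sz * q_len * hidden_size);
    }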