Add some llama-3.2 examples. (#2508)

* Add some llama-3.2 examples.
* Support tie-word-embeddings for llama.
@@ -35,6 +35,10 @@ enum Which {
     V31,
     V3Instruct,
     V31Instruct,
+    V32_1b,
+    V32_1bInstruct,
+    V32_3b,
+    V32_3bInstruct,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
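For context, variants without an explicit `#[value(name = ...)]` get clap's derived kebab-case value names. A minimal sketch, assuming clap 4 with the `derive` feature; the derived names in the comments are an inference from clap's naming rules, not taken from this diff:

```rust
use clap::ValueEnum;

#[derive(Clone, Debug, ValueEnum)]
enum Which {
    V32_1b,         // derived CLI name: `v32-1b` (assumed)
    V32_1bInstruct, // derived CLI name: `v32-1b-instruct` (assumed)
    #[value(name = "solar-10.7b")]
    Solar10_7B, // explicit override, as in the hunk above
}

fn main() {
    // Print the value name clap will accept for each variant.
    for v in Which::value_variants() {
        let name = v.to_possible_value().expect("no skipped variants");
        println!("{:?} -> {}", v, name.get_name());
    }
}
```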
@@ -137,6 +141,10 @@ fn main() -> Result<()> {
         Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
         Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
         Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
+        Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(),
+        Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
+        Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(),
+        Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
         Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
         Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
     });
@@ -156,10 +164,14 @@ fn main() -> Result<()> {
             | Which::V3Instruct
             | Which::V31
             | Which::V31Instruct
+            | Which::V32_3b
+            | Which::V32_3bInstruct
             | Which::Solar10_7B => {
                 candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
             }
-            Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
+            Which::V32_1b | Which::V32_1bInstruct | Which::TinyLlama1_1BChat => {
+                vec![api.get("model.safetensors")?]
+            }
         };
         let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
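The hunk above routes the small Llama 3.2 checkpoints to a single `model.safetensors` file, while the larger models fetch every shard listed in `model.safetensors.index.json`. A minimal sketch of the single-file path using the `hf-hub` crate's sync API (repo id and filename as in the diff; this assumes the gated Meta repo is accessible, e.g. via a cached Hub token):

```rust
use hf_hub::api::sync::Api;

fn main() -> anyhow::Result<()> {
    let api = Api::new()?;
    // Llama-3.2-1B ships as one safetensors file, so a single `get` suffices;
    // sharded models instead download each file named in the index json.
    let repo = api.model("meta-llama/Llama-3.2-1B".to_string());
    let weights = repo.get("model.safetensors")?;
    println!("weights cached at {}", weights.display());
    Ok(())
}
```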
@@ -44,6 +44,7 @@ pub struct LlamaConfig {
     pub eos_token_id: Option<LlamaEosToks>,
     pub rope_scaling: Option<Llama3RopeConfig>,
     pub max_position_embeddings: usize,
+    pub tie_word_embeddings: Option<bool>,
 }

 impl LlamaConfig {
@@ -72,6 +73,7 @@ impl LlamaConfig {
             eos_token_id: self.eos_token_id,
             rope_scaling: self.rope_scaling,
             max_position_embeddings: self.max_position_embeddings,
+            tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
         }
     }
 }
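The `Option<bool>`-then-`unwrap_or(false)` step exists because older `config.json` files omit `tie_word_embeddings` entirely, while Llama 3.2 checkpoints set it explicitly. A small self-contained check of that deserialization behavior, assuming `serde`/`serde_json` as the repo already uses:

```rust
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct Cfg {
    tie_word_embeddings: Option<bool>,
}

fn main() -> serde_json::Result<()> {
    // Field absent (e.g. a Llama 3.1 config): the Option defaults to None.
    let old: Cfg = serde_json::from_str("{}")?;
    // Field present (e.g. a Llama 3.2 config).
    let new: Cfg = serde_json::from_str(r#"{"tie_word_embeddings": true}"#)?;
    assert!(!old.tie_word_embeddings.unwrap_or(false));
    assert!(new.tie_word_embeddings.unwrap_or(false));
    Ok(())
}
```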
@@ -91,6 +93,7 @@ pub struct Config {
     pub eos_token_id: Option<LlamaEosToks>,
     pub rope_scaling: Option<Llama3RopeConfig>,
     pub max_position_embeddings: usize,
+    pub tie_word_embeddings: bool,
 }

 impl Config {
@@ -109,6 +112,7 @@ impl Config {
             eos_token_id: None,
             rope_scaling: None,
             max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
+            tie_word_embeddings: false,
         }
     }
@@ -127,6 +131,7 @@ impl Config {
             eos_token_id: None,
             rope_scaling: None,
             max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
+            tie_word_embeddings: false,
         }
     }
 }
@@ -504,7 +509,11 @@ impl Llama {

     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
-        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+        let lm_head = if cfg.tie_word_embeddings {
+            Linear::from_weights(wte.embeddings().clone(), None)
+        } else {
+            linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+        };
         let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.num_hidden_layers)
             .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
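When `tie_word_embeddings` is true, the output projection reuses the token-embedding matrix instead of loading a separate `lm_head` tensor, which is how the 1B/3B checkpoints save parameters. A toy sketch of the same tying with the `candle-core`/`candle-nn` crates, using a random matrix in place of a real checkpoint tensor:

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Embedding, Linear, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (vocab_size, hidden_size) = (16usize, 8usize);
    // Toy table standing in for `model.embed_tokens.weight`, shape (vocab, hidden).
    let table = Tensor::randn(0f32, 1f32, (vocab_size, hidden_size), &dev)?;
    let wte = Embedding::new(table.clone(), hidden_size);
    // Tied head: the same matrix, no bias; logits are hidden states projected
    // back onto the embedding rows.
    let lm_head = Linear::new(table, None);
    let ids = Tensor::new(&[1u32, 2, 3], &dev)?;
    let hidden = wte.forward(&ids)?; // (3, hidden_size)
    let logits = lm_head.forward(&hidden)?; // (3, vocab_size)
    assert_eq!(logits.dims(), &[3, vocab_size]);
    Ok(())
}
```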
|
@ -43,6 +43,7 @@ pub struct LLaVAConfig {
|
|||||||
pub image_token_index: isize,
|
pub image_token_index: isize,
|
||||||
#[serde(default = "default_hf")]
|
#[serde(default = "default_hf")]
|
||||||
pub hf: bool,
|
pub hf: bool,
|
||||||
|
pub tie_word_embeddings: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_hf() -> bool {
|
fn default_hf() -> bool {
|
||||||
@@ -77,6 +78,7 @@ impl LLaVAConfig {
             use_flash_attn: false,
             rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
             max_position_embeddings: self.max_position_embeddings,
+            tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
         }
     }
 }
@@ -264,6 +266,7 @@ impl HFLLaVAConfig {
             use_cache: self.text_config.use_cache,
             vocab_size: self.vocab_size,
             image_token_index: self.image_token_index,
+            tie_word_embeddings: None,
         }
     }
 }