mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Support for the new Qwen2 models. (#2257)
* Support for the new Qwen2 models.
* Add more models.
This commit is contained in:
@ -144,6 +144,14 @@ enum WhichModel {
|
|||||||
W72b,
|
W72b,
|
||||||
#[value(name = "moe-a2.7b")]
|
#[value(name = "moe-a2.7b")]
|
||||||
MoeA27b,
|
MoeA27b,
|
||||||
|
#[value(name = "2-0.5b")]
|
||||||
|
W2_0_5b,
|
||||||
|
#[value(name = "2-1.5b")]
|
||||||
|
W2_1_5b,
|
||||||
|
#[value(name = "2-7b")]
|
||||||
|
W2_7b,
|
||||||
|
#[value(name = "2-72b")]
|
||||||
|
W2_72b,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -234,16 +242,20 @@ fn main() -> Result<()> {
|
|||||||
let model_id = match args.model_id {
|
let model_id = match args.model_id {
|
||||||
Some(model_id) => model_id,
|
Some(model_id) => model_id,
|
||||||
None => {
|
None => {
|
||||||
let size = match args.model {
|
let (version, size) = match args.model {
|
||||||
WhichModel::W0_5b => "0.5B",
|
WhichModel::W2_0_5b => ("2", "0.5B"),
|
||||||
WhichModel::W1_8b => "1.8B",
|
WhichModel::W2_1_5b => ("2", "1.5B"),
|
||||||
WhichModel::W4b => "4B",
|
WhichModel::W2_7b => ("2", "7B"),
|
||||||
WhichModel::W7b => "7B",
|
WhichModel::W2_72b => ("2", "72B"),
|
||||||
WhichModel::W14b => "14B",
|
WhichModel::W0_5b => ("1.5", "0.5B"),
|
||||||
WhichModel::W72b => "72B",
|
WhichModel::W1_8b => ("1.5", "1.8B"),
|
||||||
WhichModel::MoeA27b => "MoE-A2.7B",
|
WhichModel::W4b => ("1.5", "4B"),
|
||||||
|
WhichModel::W7b => ("1.5", "7B"),
|
||||||
|
WhichModel::W14b => ("1.5", "14B"),
|
||||||
|
WhichModel::W72b => ("1.5", "72B"),
|
||||||
|
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
|
||||||
};
|
};
|
||||||
format!("Qwen/Qwen1.5-{size}")
|
format!("Qwen/Qwen{version}-{size}")
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let repo = api.repo(Repo::with_revision(
|
let repo = api.repo(Repo::with_revision(
|
||||||
@ -261,11 +273,15 @@ fn main() -> Result<()> {
|
|||||||
.map(std::path::PathBuf::from)
|
.map(std::path::PathBuf::from)
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
None => match args.model {
|
None => match args.model {
|
||||||
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
|
WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
|
||||||
|
vec![repo.get("model.safetensors")?]
|
||||||
|
}
|
||||||
WhichModel::W4b
|
WhichModel::W4b
|
||||||
| WhichModel::W7b
|
| WhichModel::W7b
|
||||||
|
| WhichModel::W2_7b
|
||||||
| WhichModel::W14b
|
| WhichModel::W14b
|
||||||
| WhichModel::W72b
|
| WhichModel::W72b
|
||||||
|
| WhichModel::W2_72b
|
||||||
| WhichModel::MoeA27b => {
|
| WhichModel::MoeA27b => {
|
||||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||||
}
|
}
|
||||||
|
@ -360,8 +360,12 @@ pub struct ModelForCausalLM {
|
|||||||
|
|
||||||
impl ModelForCausalLM {
|
impl ModelForCausalLM {
|
||||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
let base_model = Model::new(cfg, vb.clone())?;
|
||||||
let base_model = Model::new(cfg, vb)?;
|
let lm_head = if vb.contains_tensor("lm_head") {
|
||||||
|
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||||
|
} else {
|
||||||
|
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
base_model,
|
base_model,
|
||||||
lm_head,
|
lm_head,
|
||||||
|
Reference in New Issue
Block a user