Support for the new Qwen2 models. (#2257)

* Support for the new Qwen2 models.

* Add more models.
This commit is contained in:
Laurent Mazare
2024-06-07 10:51:50 +01:00
committed by GitHub
parent b9fac7ec00
commit 54ff971e35
2 changed files with 32 additions and 12 deletions

View File

@@ -144,6 +144,14 @@ enum WhichModel {
W72b, W72b,
#[value(name = "moe-a2.7b")] #[value(name = "moe-a2.7b")]
MoeA27b, MoeA27b,
#[value(name = "2-0.5b")]
W2_0_5b,
#[value(name = "2-1.5b")]
W2_1_5b,
#[value(name = "2-7b")]
W2_7b,
#[value(name = "2-72b")]
W2_72b,
} }
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@@ -234,16 +242,20 @@ fn main() -> Result<()> {
let model_id = match args.model_id { let model_id = match args.model_id {
Some(model_id) => model_id, Some(model_id) => model_id,
None => { None => {
let size = match args.model { let (version, size) = match args.model {
WhichModel::W0_5b => "0.5B", WhichModel::W2_0_5b => ("2", "0.5B"),
WhichModel::W1_8b => "1.8B", WhichModel::W2_1_5b => ("2", "1.5B"),
WhichModel::W4b => "4B", WhichModel::W2_7b => ("2", "7B"),
WhichModel::W7b => "7B", WhichModel::W2_72b => ("2", "72B"),
WhichModel::W14b => "14B", WhichModel::W0_5b => ("1.5", "0.5B"),
WhichModel::W72b => "72B", WhichModel::W1_8b => ("1.5", "1.8B"),
WhichModel::MoeA27b => "MoE-A2.7B", WhichModel::W4b => ("1.5", "4B"),
WhichModel::W7b => ("1.5", "7B"),
WhichModel::W14b => ("1.5", "14B"),
WhichModel::W72b => ("1.5", "72B"),
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
}; };
format!("Qwen/Qwen1.5-{size}") format!("Qwen/Qwen{version}-{size}")
} }
}; };
let repo = api.repo(Repo::with_revision( let repo = api.repo(Repo::with_revision(
@@ -261,11 +273,15 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from) .map(std::path::PathBuf::from)
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
None => match args.model { None => match args.model {
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?], WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
vec![repo.get("model.safetensors")?]
}
WhichModel::W4b WhichModel::W4b
| WhichModel::W7b | WhichModel::W7b
| WhichModel::W2_7b
| WhichModel::W14b | WhichModel::W14b
| WhichModel::W72b | WhichModel::W72b
| WhichModel::W2_72b
| WhichModel::MoeA27b => { | WhichModel::MoeA27b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
} }

View File

@@ -360,8 +360,12 @@ pub struct ModelForCausalLM {
impl ModelForCausalLM { impl ModelForCausalLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let base_model = Model::new(cfg, vb.clone())?;
let base_model = Model::new(cfg, vb)?; let lm_head = if vb.contains_tensor("lm_head") {
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
} else {
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
};
Ok(Self { Ok(Self {
base_model, base_model,
lm_head, lm_head,