Add Qwen3 MoE (#2934)

* qwen-moe rebase

* lint

* fixed rebase error

* replaced the plain MoE model with the causal-LM MoE model (`ModelForCausalLM`) in the example, and swapped the tie-word-embeddings if statement

* updated readme
This commit is contained in:
Kyle Birnbaum
2025-05-31 06:33:28 -07:00
committed by GitHub
parent cd7b877d6b
commit 0224a749f0
4 changed files with 393 additions and 1 deletion

View File

@@ -10,6 +10,7 @@ use clap::Parser;
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, ModelForCausalLM as ModelMoe3};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@@ -22,6 +23,7 @@ enum Model {
Base(ModelBase),
Moe(ModelMoe),
Base3(Model3),
Moe3(ModelMoe3),
}
impl Model {
@@ -30,6 +32,7 @@ impl Model {
Self::Moe(ref mut m) => m.forward(xs, s),
Self::Base(ref mut m) => m.forward(xs, s),
Self::Base3(ref mut m) => m.forward(xs, s),
Self::Moe3(ref mut m) => m.forward(xs, s),
}
}
}
@@ -167,6 +170,8 @@ enum WhichModel {
W3_4b,
#[value(name = "3-8b")]
W3_8b,
#[value(name = "3-moe-a3b")]
W3MoeA3b,
}
#[derive(Parser, Debug)]
@@ -273,6 +278,7 @@ fn main() -> Result<()> {
WhichModel::W3_1_7b => ("3", "1.7B"),
WhichModel::W3_4b => ("3", "4B"),
WhichModel::W3_8b => ("3", "8B"),
WhichModel::W3MoeA3b => ("3", "30B-A3B"),
};
format!("Qwen/Qwen{version}-{size}")
}
@@ -308,7 +314,8 @@ fn main() -> Result<()> {
| WhichModel::MoeA27b
| WhichModel::W3_1_7b
| WhichModel::W3_4b
| WhichModel::W3_8b => {
| WhichModel::W3_8b
| WhichModel::W3MoeA3b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
@@ -334,6 +341,10 @@ fn main() -> Result<()> {
let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base3(Model3::new(&config, vb)?)
}
WhichModel::W3MoeA3b => {
let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Moe3(ModelMoe3::new(&config, vb)?)
}
_ => {
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base(ModelBase::new(&config, vb)?)