mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add Qwen3 MoE (#2934)
* qwen-moe rebase * lint * fixed rebase error * swapped normal MoE model with CausalMoE Model in example, and swapped the tie word embeddings if statement * updated readme
This commit is contained in:
@ -10,6 +10,7 @@ use clap::Parser;
|
||||
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
|
||||
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
|
||||
use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
|
||||
use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, ModelForCausalLM as ModelMoe3};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
@ -22,6 +23,7 @@ enum Model {
|
||||
Base(ModelBase),
|
||||
Moe(ModelMoe),
|
||||
Base3(Model3),
|
||||
Moe3(ModelMoe3),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
@ -30,6 +32,7 @@ impl Model {
|
||||
Self::Moe(ref mut m) => m.forward(xs, s),
|
||||
Self::Base(ref mut m) => m.forward(xs, s),
|
||||
Self::Base3(ref mut m) => m.forward(xs, s),
|
||||
Self::Moe3(ref mut m) => m.forward(xs, s),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -167,6 +170,8 @@ enum WhichModel {
|
||||
W3_4b,
|
||||
#[value(name = "3-8b")]
|
||||
W3_8b,
|
||||
#[value(name = "3-moe-a3b")]
|
||||
W3MoeA3b,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -273,6 +278,7 @@ fn main() -> Result<()> {
|
||||
WhichModel::W3_1_7b => ("3", "1.7B"),
|
||||
WhichModel::W3_4b => ("3", "4B"),
|
||||
WhichModel::W3_8b => ("3", "8B"),
|
||||
WhichModel::W3MoeA3b => ("3", "30B-A3B"),
|
||||
};
|
||||
format!("Qwen/Qwen{version}-{size}")
|
||||
}
|
||||
@ -308,7 +314,8 @@ fn main() -> Result<()> {
|
||||
| WhichModel::MoeA27b
|
||||
| WhichModel::W3_1_7b
|
||||
| WhichModel::W3_4b
|
||||
| WhichModel::W3_8b => {
|
||||
| WhichModel::W3_8b
|
||||
| WhichModel::W3MoeA3b => {
|
||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||
}
|
||||
},
|
||||
@ -334,6 +341,10 @@ fn main() -> Result<()> {
|
||||
let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||
Model::Base3(Model3::new(&config, vb)?)
|
||||
}
|
||||
WhichModel::W3MoeA3b => {
|
||||
let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||
Model::Moe3(ModelMoe3::new(&config, vb)?)
|
||||
}
|
||||
_ => {
|
||||
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||
Model::Base(ModelBase::new(&config, vb)?)
|
||||
|
Reference in New Issue
Block a user