Add Qwen3 MoE (#2934)

* qwen-moe rebase

* lint

* fixed rebase error

* swapped normal MoE model with CausalMoE Model in example, and swapped the tie word embeddings if statement

* updated readme
This commit is contained in:
Kyle Birnbaum
2025-05-31 06:33:28 -07:00
committed by GitHub
parent cd7b877d6b
commit 0224a749f0
4 changed files with 393 additions and 1 deletions

View File

@ -25,3 +25,28 @@ def print_prime(n: int): # n is the number of primes to be printed
print(i) print(i)
``` ```
The qwen3 MoE variant is also an option.
```bash
$ cargo run --example qwen --features metal --release -- --prompt "Write a poem about butterflies. <think></think>." --model "3-moe-a3b"
> In morning's hush, where daisies sleep,
> A fleeting dance through sunlit deep—
> They flutter soft on gossamer thread,
> The messengers of springs own head.
>
> With painted sails and delicate grace,
> They drift from bloom to blossom's face.
> Each wing a tale in hues unseen,
> Of ancient dreams and secrets between.
>
> No sound they make, yet still they speak—
> Of time that flies, of life so brief.
> A fleeting kiss on summers breath,
> A whisper lost before death.
>
> Yet in their flight, the soul takes wing,
> And for a moment, all is spring.
> For though they fade, they never die—
> Their beauty lives where hearts can fly.
> 161 tokens generated (3.00 token/s)
```

View File

@ -10,6 +10,7 @@ use clap::Parser;
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase}; use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe}; use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3}; use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, ModelForCausalLM as ModelMoe3};
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream; use candle_examples::token_output_stream::TokenOutputStream;
@ -22,6 +23,7 @@ enum Model {
Base(ModelBase), Base(ModelBase),
Moe(ModelMoe), Moe(ModelMoe),
Base3(Model3), Base3(Model3),
Moe3(ModelMoe3),
} }
impl Model { impl Model {
@ -30,6 +32,7 @@ impl Model {
Self::Moe(ref mut m) => m.forward(xs, s), Self::Moe(ref mut m) => m.forward(xs, s),
Self::Base(ref mut m) => m.forward(xs, s), Self::Base(ref mut m) => m.forward(xs, s),
Self::Base3(ref mut m) => m.forward(xs, s), Self::Base3(ref mut m) => m.forward(xs, s),
Self::Moe3(ref mut m) => m.forward(xs, s),
} }
} }
} }
@ -167,6 +170,8 @@ enum WhichModel {
W3_4b, W3_4b,
#[value(name = "3-8b")] #[value(name = "3-8b")]
W3_8b, W3_8b,
#[value(name = "3-moe-a3b")]
W3MoeA3b,
} }
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -273,6 +278,7 @@ fn main() -> Result<()> {
WhichModel::W3_1_7b => ("3", "1.7B"), WhichModel::W3_1_7b => ("3", "1.7B"),
WhichModel::W3_4b => ("3", "4B"), WhichModel::W3_4b => ("3", "4B"),
WhichModel::W3_8b => ("3", "8B"), WhichModel::W3_8b => ("3", "8B"),
WhichModel::W3MoeA3b => ("3", "30B-A3B"),
}; };
format!("Qwen/Qwen{version}-{size}") format!("Qwen/Qwen{version}-{size}")
} }
@ -308,7 +314,8 @@ fn main() -> Result<()> {
| WhichModel::MoeA27b | WhichModel::MoeA27b
| WhichModel::W3_1_7b | WhichModel::W3_1_7b
| WhichModel::W3_4b | WhichModel::W3_4b
| WhichModel::W3_8b => { | WhichModel::W3_8b
| WhichModel::W3MoeA3b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
} }
}, },
@ -334,6 +341,10 @@ fn main() -> Result<()> {
let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?; let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base3(Model3::new(&config, vb)?) Model::Base3(Model3::new(&config, vb)?)
} }
WhichModel::W3MoeA3b => {
let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Moe3(ModelMoe3::new(&config, vb)?)
}
_ => { _ => {
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?; let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base(ModelBase::new(&config, vb)?) Model::Base(ModelBase::new(&config, vb)?)

View File

@ -100,6 +100,7 @@ pub mod quantized_t5;
pub mod qwen2; pub mod qwen2;
pub mod qwen2_moe; pub mod qwen2_moe;
pub mod qwen3; pub mod qwen3;
pub mod qwen3_moe;
pub mod recurrent_gemma; pub mod recurrent_gemma;
pub mod repvgg; pub mod repvgg;
pub mod resnet; pub mod resnet;

View File

@ -0,0 +1,355 @@
use crate::models::{
qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding},
with_tracing::{linear_no_bias, Linear, RmsNorm},
};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub head_dim: usize,
pub attention_bias: bool,
pub num_key_value_heads: usize,
pub max_position_embeddings: usize,
pub sliding_window: Option<usize>,
pub max_window_layers: usize,
pub tie_word_embeddings: bool,
pub rope_theta: f64,
pub rms_norm_eps: f64,
pub use_sliding_window: bool,
pub hidden_act: Activation,
// MoE specific configuration
pub decoder_sparse_step: usize,
pub moe_intermediate_size: usize,
pub num_experts_per_tok: usize,
pub num_experts: usize,
pub norm_topk_prob: bool,
}
impl From<&Config> for Qwen3Config {
fn from(val: &Config) -> Self {
Qwen3Config {
vocab_size: val.vocab_size,
hidden_size: val.hidden_size,
intermediate_size: val.intermediate_size,
num_hidden_layers: val.num_hidden_layers,
num_attention_heads: val.num_attention_heads,
head_dim: val.head_dim,
attention_bias: val.attention_bias,
num_key_value_heads: val.num_key_value_heads,
max_position_embeddings: val.max_position_embeddings,
sliding_window: val.sliding_window,
max_window_layers: val.max_window_layers,
tie_word_embeddings: val.tie_word_embeddings,
rope_theta: val.rope_theta,
rms_norm_eps: val.rms_norm_eps,
use_sliding_window: val.use_sliding_window,
hidden_act: val.hidden_act,
}
}
}
#[derive(Debug, Clone)]
struct Qwen3MLPExpert {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: Activation,
}
impl Qwen3MLPExpert {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
Ok(Self {
gate_proj: linear_no_bias(
cfg.hidden_size,
cfg.moe_intermediate_size,
vb.pp("gate_proj"),
)?,
up_proj: linear_no_bias(cfg.hidden_size, cfg.moe_intermediate_size, vb.pp("up_proj"))?,
down_proj: linear_no_bias(
cfg.moe_intermediate_size,
cfg.hidden_size,
vb.pp("down_proj"),
)?,
act_fn: cfg.hidden_act,
})
}
}
impl Module for Qwen3MLPExpert {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = x.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
// Qwen3 Sparse MoE Block implementation
#[derive(Debug, Clone)]
struct Qwen3SparseMoeBlock {
gate: Linear,
experts: Vec<Qwen3MLPExpert>,
norm_topk_prob: bool,
num_experts_per_tok: usize,
}
impl Qwen3SparseMoeBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let gate = linear_no_bias(cfg.hidden_size, cfg.num_experts, vb.pp("gate"))?;
let mut experts = Vec::with_capacity(cfg.num_experts);
let vb_e = vb.pp("experts");
for idx in 0..cfg.num_experts {
let expert = Qwen3MLPExpert::new(cfg, vb_e.pp(idx))?;
experts.push(expert)
}
Ok(Self {
gate,
experts,
norm_topk_prob: cfg.norm_topk_prob,
num_experts_per_tok: cfg.num_experts_per_tok,
})
}
}
impl Module for Qwen3SparseMoeBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b_size, seq_len, hidden_dim) = xs.dims3()?;
let xs = xs.reshape(((), hidden_dim))?;
let router_logits = xs.apply(&self.gate)?;
let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
// Extract topk experts per token
let experts_per_tok = routing_weights
.arg_sort_last_dim(false)?
.narrow(D::Minus1, 0, self.num_experts_per_tok)?
.contiguous()?;
let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
// Extract needed data
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
let mut top_x = vec![vec![]; self.experts.len()];
let mut selected_experts = vec![vec![]; self.experts.len()];
for (row_idx, (rw, expert_idxs)) in routing_weights
.iter()
.zip(experts_per_tok.iter())
.enumerate()
{
let sum_rw = rw.iter().sum::<f32>();
for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
top_x[expert_idx as usize].push(row_idx as u32);
let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
selected_experts[expert_idx as usize].push(rw)
}
}
// Process through experts
let mut ys = xs.zeros_like()?;
for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
let top_x = &top_x[expert_idx];
if top_x.is_empty() {
continue;
}
let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
let selected_experts =
Tensor::new(selected_experts[expert_idx].as_slice(), xs.device())?
.reshape(((), 1))?
.to_dtype(xs.dtype())?;
let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
let current_hidden_states = expert_layer.forward(&current_state)?;
let current_hidden_states = current_hidden_states.broadcast_mul(&selected_experts)?;
ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
}
ys.reshape((b_size, seq_len, hidden_dim))
}
}
// MLP or MoE decision enum
#[derive(Debug, Clone)]
enum Qwen3FeedForward {
Mlp(Qwen3MLP),
MoE(Qwen3SparseMoeBlock),
}
impl Module for Qwen3FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Mlp(m) => m.forward(xs),
Self::MoE(m) => m.forward(xs),
}
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Qwen3Attention,
feed_forward: Qwen3FeedForward,
ln1: RmsNorm,
ln2: RmsNorm,
}
impl DecoderLayer {
fn new(
layer_idx: usize,
cfg: &Config,
rotary: Arc<Qwen3RotaryEmbedding>,
vb: VarBuilder,
) -> Result<Self> {
let self_attn = Qwen3Attention::new(&cfg.into(), rotary, vb.pp("self_attn"))?;
// Decide whether to use MoE or regular MLP based on layer_idx and decoder_sparse_step
let feed_forward = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0
{
Qwen3FeedForward::MoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?)
} else {
Qwen3FeedForward::Mlp(Qwen3MLP::new(&cfg.into(), vb.pp("mlp"))?)
};
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let ln2 = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
feed_forward,
ln1,
ln2,
})
}
fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
let h = self.ln1.forward(x)?;
let h = self.self_attn.forward(&h, mask, offset)?;
let x = (x + h)?;
let h2 = self.ln2.forward(&x)?;
let h2 = h2.apply(&self.feed_forward)?;
x + h2
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache();
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let rotary = Arc::new(Qwen3RotaryEmbedding::new(
vb.dtype(),
&cfg.into(),
vb.device(),
)?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb.pp("model.layers");
for i in 0..cfg.num_hidden_layers {
layers.push(DecoderLayer::new(i, cfg, rotary.clone(), vb_l.pp(i))?);
}
Ok(Self {
embed_tokens,
layers,
norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
fn clear_kv_cache(&mut self) {
for l in &mut self.layers {
l.clear_kv_cache();
}
}
fn causal_mask(
&self,
b: usize,
tgt: usize,
offset: usize,
sw: Option<usize>,
) -> Result<Tensor> {
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| {
(0..(tgt + offset)).map(move |j| {
let past_ok = j <= i + offset;
let sw_ok = match sw {
Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
None => true,
};
if past_ok && sw_ok {
0.
} else {
minf
}
})
})
.collect();
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
let (b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;
let causal = if l == 1 {
None
} else {
Some(self.causal_mask(b, l, offset, None)?)
};
for layer in &mut self.layers {
h = layer.forward(&h, causal.as_ref(), offset)?;
}
self.norm.forward(&h)
}
}
#[derive(Debug, Clone)]
pub struct ModelForCausalLM {
base: Model,
lm_head: Linear,
}
impl ModelForCausalLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let base = Model::new(cfg, vb.clone())?;
let lm_head = if cfg.tie_word_embeddings {
Linear::from_weights(base.embed_tokens.embeddings().clone(), None)
} else {
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
};
Ok(Self { base, lm_head })
}
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
let (_, l) = input.dims2()?;
self.base
.forward(input, offset)?
.narrow(1, l - 1, 1)?
.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
self.base.clear_kv_cache();
}
}