diff --git a/candle-examples/examples/qwen/README.md b/candle-examples/examples/qwen/README.md
index cb785f21..d81cd666 100644
--- a/candle-examples/examples/qwen/README.md
+++ b/candle-examples/examples/qwen/README.md
@@ -25,3 +25,28 @@ def print_prime(n: int): # n is the number of primes to be printed
print(i)
```
+The Qwen3 MoE variant is also supported; select it with `--model "3-moe-a3b"`:
+
+```bash
+$ cargo run --example qwen --features metal --release -- --prompt "Write a poem about butterflies." --model "3-moe-a3b"
+> In morning's hush, where daisies sleep,
+> A fleeting dance through sunlit deep—
+> They flutter soft on gossamer thread,
+> The messengers of spring’s own head.
+>
+> With painted sails and delicate grace,
+> They drift from bloom to blossom's face.
+> Each wing a tale in hues unseen,
+> Of ancient dreams and secrets between.
+>
+> No sound they make, yet still they speak—
+> Of time that flies, of life so brief.
+> A fleeting kiss on summer’s breath,
+> A whisper lost before death.
+>
+> Yet in their flight, the soul takes wing,
+> And for a moment, all is spring.
+> For though they fade, they never die—
+> Their beauty lives where hearts can fly.
+> 161 tokens generated (3.00 token/s)
+```
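+
+On machines without Metal, the example should also build with another backend feature, e.g. `--features cuda`, or with no accelerator feature to run on the CPU.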
diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs
index d0e179e0..3b90b9fb 100644
--- a/candle-examples/examples/qwen/main.rs
+++ b/candle-examples/examples/qwen/main.rs
@@ -10,6 +10,7 @@ use clap::Parser;
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
+use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, ModelForCausalLM as ModelMoe3};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@@ -22,6 +23,7 @@ enum Model {
Base(ModelBase),
Moe(ModelMoe),
Base3(Model3),
+ Moe3(ModelMoe3),
}
impl Model {
@@ -30,6 +32,7 @@ impl Model {
Self::Moe(ref mut m) => m.forward(xs, s),
Self::Base(ref mut m) => m.forward(xs, s),
Self::Base3(ref mut m) => m.forward(xs, s),
+ Self::Moe3(ref mut m) => m.forward(xs, s),
}
}
}
@@ -167,6 +170,8 @@ enum WhichModel {
W3_4b,
#[value(name = "3-8b")]
W3_8b,
+ #[value(name = "3-moe-a3b")]
+ W3MoeA3b,
}
#[derive(Parser, Debug)]
@@ -273,6 +278,7 @@ fn main() -> Result<()> {
WhichModel::W3_1_7b => ("3", "1.7B"),
WhichModel::W3_4b => ("3", "4B"),
WhichModel::W3_8b => ("3", "8B"),
+ WhichModel::W3MoeA3b => ("3", "30B-A3B"),
};
format!("Qwen/Qwen{version}-{size}")
}
@@ -308,7 +314,8 @@ fn main() -> Result<()> {
| WhichModel::MoeA27b
| WhichModel::W3_1_7b
| WhichModel::W3_4b
- | WhichModel::W3_8b => {
+ | WhichModel::W3_8b
+ | WhichModel::W3MoeA3b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
@@ -334,6 +341,10 @@ fn main() -> Result<()> {
let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base3(Model3::new(&config, vb)?)
}
+ WhichModel::W3MoeA3b => {
+ let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
+ Model::Moe3(ModelMoe3::new(&config, vb)?)
+ }
_ => {
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base(ModelBase::new(&config, vb)?)
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index d8f71b44..8d80b183 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -100,6 +100,7 @@ pub mod quantized_t5;
pub mod qwen2;
pub mod qwen2_moe;
pub mod qwen3;
+pub mod qwen3_moe;
pub mod recurrent_gemma;
pub mod repvgg;
pub mod resnet;
diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs
new file mode 100644
index 00000000..e88a0538
--- /dev/null
+++ b/candle-transformers/src/models/qwen3_moe.rs
@@ -0,0 +1,355 @@
+use crate::models::{
+ qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding},
+ with_tracing::{linear_no_bias, Linear, RmsNorm},
+};
+use candle::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{Activation, VarBuilder};
+use std::sync::Arc;
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+pub struct Config {
+ pub vocab_size: usize,
+ pub hidden_size: usize,
+ pub intermediate_size: usize,
+ pub num_hidden_layers: usize,
+ pub num_attention_heads: usize,
+ pub head_dim: usize,
+ pub attention_bias: bool,
+ pub num_key_value_heads: usize,
+ pub max_position_embeddings: usize,
+    pub sliding_window: Option<usize>,
+ pub max_window_layers: usize,
+ pub tie_word_embeddings: bool,
+ pub rope_theta: f64,
+ pub rms_norm_eps: f64,
+ pub use_sliding_window: bool,
+ pub hidden_act: Activation,
+ // MoE specific configuration
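+    // A layer uses the sparse MoE block when `num_experts > 0` and
+    // `(layer_idx + 1) % decoder_sparse_step == 0`; otherwise it uses a dense MLP.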
+ pub decoder_sparse_step: usize,
+ pub moe_intermediate_size: usize,
+ pub num_experts_per_tok: usize,
+ pub num_experts: usize,
+ pub norm_topk_prob: bool,
+}
+
+impl From<&Config> for Qwen3Config {
+ fn from(val: &Config) -> Self {
+ Qwen3Config {
+ vocab_size: val.vocab_size,
+ hidden_size: val.hidden_size,
+ intermediate_size: val.intermediate_size,
+ num_hidden_layers: val.num_hidden_layers,
+ num_attention_heads: val.num_attention_heads,
+ head_dim: val.head_dim,
+ attention_bias: val.attention_bias,
+ num_key_value_heads: val.num_key_value_heads,
+ max_position_embeddings: val.max_position_embeddings,
+ sliding_window: val.sliding_window,
+ max_window_layers: val.max_window_layers,
+ tie_word_embeddings: val.tie_word_embeddings,
+ rope_theta: val.rope_theta,
+ rms_norm_eps: val.rms_norm_eps,
+ use_sliding_window: val.use_sliding_window,
+ hidden_act: val.hidden_act,
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Qwen3MLPExpert {
+ gate_proj: Linear,
+ up_proj: Linear,
+ down_proj: Linear,
+ act_fn: Activation,
+}
+
+impl Qwen3MLPExpert {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ Ok(Self {
+ gate_proj: linear_no_bias(
+ cfg.hidden_size,
+ cfg.moe_intermediate_size,
+ vb.pp("gate_proj"),
+ )?,
+ up_proj: linear_no_bias(cfg.hidden_size, cfg.moe_intermediate_size, vb.pp("up_proj"))?,
+ down_proj: linear_no_bias(
+ cfg.moe_intermediate_size,
+ cfg.hidden_size,
+ vb.pp("down_proj"),
+ )?,
+ act_fn: cfg.hidden_act,
+ })
+ }
+}
+
+impl Module for Qwen3MLPExpert {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
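+        // Gated feed-forward: down_proj(act(gate_proj(x)) * up_proj(x)),
+        // i.e. SwiGLU when `hidden_act` is SiLU.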
+ let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
+ let rhs = x.apply(&self.up_proj)?;
+ (lhs * rhs)?.apply(&self.down_proj)
+ }
+}
+
+// Qwen3 Sparse MoE Block implementation
+#[derive(Debug, Clone)]
+struct Qwen3SparseMoeBlock {
+ gate: Linear,
+    experts: Vec<Qwen3MLPExpert>,
+ norm_topk_prob: bool,
+ num_experts_per_tok: usize,
+}
+
+impl Qwen3SparseMoeBlock {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let gate = linear_no_bias(cfg.hidden_size, cfg.num_experts, vb.pp("gate"))?;
+ let mut experts = Vec::with_capacity(cfg.num_experts);
+ let vb_e = vb.pp("experts");
+ for idx in 0..cfg.num_experts {
+ let expert = Qwen3MLPExpert::new(cfg, vb_e.pp(idx))?;
+ experts.push(expert)
+ }
+ Ok(Self {
+ gate,
+ experts,
+ norm_topk_prob: cfg.norm_topk_prob,
+ num_experts_per_tok: cfg.num_experts_per_tok,
+ })
+ }
+}
+
+impl Module for Qwen3SparseMoeBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (b_size, seq_len, hidden_dim) = xs.dims3()?;
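+        // Flatten (batch, seq, hidden) to (batch * seq, hidden) so each token is routed independently.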
+ let xs = xs.reshape(((), hidden_dim))?;
+ let router_logits = xs.apply(&self.gate)?;
+ let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
+
+        // Select the top-k experts for each token and gather their routing weights
+ let experts_per_tok = routing_weights
+ .arg_sort_last_dim(false)?
+ .narrow(D::Minus1, 0, self.num_experts_per_tok)?
+ .contiguous()?;
+ let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
+
+        // Move routing results to host vectors for the per-expert bookkeeping below
+        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+        let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
+ let mut top_x = vec![vec![]; self.experts.len()];
+ let mut selected_experts = vec![vec![]; self.experts.len()];
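+        // For every expert, record which token rows are routed to it (`top_x`) and the
+        // routing weight that will scale its output (`selected_experts`).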
+ for (row_idx, (rw, expert_idxs)) in routing_weights
+ .iter()
+ .zip(experts_per_tok.iter())
+ .enumerate()
+ {
+            let sum_rw = rw.iter().sum::<f32>();
+ for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
+ top_x[expert_idx as usize].push(row_idx as u32);
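+                // Optionally renormalize so the k selected weights sum to 1.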
+ let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
+ selected_experts[expert_idx as usize].push(rw)
+ }
+ }
+
+ // Process through experts
+ let mut ys = xs.zeros_like()?;
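+        // `ys` accumulates each expert's weighted output via index_add over the routed rows.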
+ for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
+ let top_x = &top_x[expert_idx];
+ if top_x.is_empty() {
+ continue;
+ }
+ let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
+ let selected_experts =
+ Tensor::new(selected_experts[expert_idx].as_slice(), xs.device())?
+ .reshape(((), 1))?
+ .to_dtype(xs.dtype())?;
+
+ let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
+            let current_hidden_states = expert_layer.forward(&current_state)?;
+ let current_hidden_states = current_hidden_states.broadcast_mul(&selected_experts)?;
+            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
+ }
+
+ ys.reshape((b_size, seq_len, hidden_dim))
+ }
+}
+
+// Per-layer feed-forward: either a dense MLP or a sparse MoE block
+#[derive(Debug, Clone)]
+enum Qwen3FeedForward {
+ Mlp(Qwen3MLP),
+ MoE(Qwen3SparseMoeBlock),
+}
+
+impl Module for Qwen3FeedForward {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Self::Mlp(m) => m.forward(xs),
+ Self::MoE(m) => m.forward(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderLayer {
+ self_attn: Qwen3Attention,
+ feed_forward: Qwen3FeedForward,
+ ln1: RmsNorm,
+ ln2: RmsNorm,
+}
+
+impl DecoderLayer {
+ fn new(
+ layer_idx: usize,
+ cfg: &Config,
+        rotary: Arc<Qwen3RotaryEmbedding>,
+ vb: VarBuilder,
+    ) -> Result<Self> {
+ let self_attn = Qwen3Attention::new(&cfg.into(), rotary, vb.pp("self_attn"))?;
+
+ // Decide whether to use MoE or regular MLP based on layer_idx and decoder_sparse_step
+ let feed_forward = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0
+ {
+ Qwen3FeedForward::MoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?)
+ } else {
+ Qwen3FeedForward::Mlp(Qwen3MLP::new(&cfg.into(), vb.pp("mlp"))?)
+ };
+
+ let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+ let ln2 = RmsNorm::new(
+ cfg.hidden_size,
+ cfg.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
+
+ Ok(Self {
+ self_attn,
+ feed_forward,
+ ln1,
+ ln2,
+ })
+ }
+
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
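+        // Pre-norm residual block: x = x + attn(ln1(x)), then x = x + ffn(ln2(x)).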
+ let h = self.ln1.forward(x)?;
+ let h = self.self_attn.forward(&h, mask, offset)?;
+ let x = (x + h)?;
+ let h2 = self.ln2.forward(&x)?;
+ let h2 = h2.apply(&self.feed_forward)?;
+ x + h2
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.self_attn.clear_kv_cache();
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+ embed_tokens: candle_nn::Embedding,
+    layers: Vec<DecoderLayer>,
+ norm: RmsNorm,
+ device: Device,
+ dtype: DType,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let embed_tokens =
+ candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
+ let rotary = Arc::new(Qwen3RotaryEmbedding::new(
+ vb.dtype(),
+ &cfg.into(),
+ vb.device(),
+ )?);
+ let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+ let vb_l = vb.pp("model.layers");
+ for i in 0..cfg.num_hidden_layers {
+ layers.push(DecoderLayer::new(i, cfg, rotary.clone(), vb_l.pp(i))?);
+ }
+ Ok(Self {
+ embed_tokens,
+ layers,
+ norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?,
+ device: vb.device().clone(),
+ dtype: vb.dtype(),
+ })
+ }
+
+ fn clear_kv_cache(&mut self) {
+ for l in &mut self.layers {
+ l.clear_kv_cache();
+ }
+ }
+
+ fn causal_mask(
+ &self,
+ b: usize,
+ tgt: usize,
+ offset: usize,
+        sw: Option<usize>,
+    ) -> Result<Tensor> {
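+        // Additive attention mask: 0 where position i may attend to j (j <= i + offset,
+        // and inside the sliding window when `sw` is set), -inf elsewhere.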
+ let minf = f32::NEG_INFINITY;
+ let mask: Vec<_> = (0..tgt)
+ .flat_map(|i| {
+ (0..(tgt + offset)).map(move |j| {
+ let past_ok = j <= i + offset;
+ let sw_ok = match sw {
+ Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
+ None => true,
+ };
+ if past_ok && sw_ok {
+ 0.
+ } else {
+ minf
+ }
+ })
+ })
+ .collect();
+ Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
+ }
+
+    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
+ let (b, l) = input.dims2()?;
+ let mut h = self.embed_tokens.forward(input)?;
+
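+        // A single-token decode step can attend to the whole KV cache, so no mask is needed.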
+ let causal = if l == 1 {
+ None
+ } else {
+ Some(self.causal_mask(b, l, offset, None)?)
+ };
+
+ for layer in &mut self.layers {
+ h = layer.forward(&h, causal.as_ref(), offset)?;
+ }
+ self.norm.forward(&h)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct ModelForCausalLM {
+ base: Model,
+ lm_head: Linear,
+}
+
+impl ModelForCausalLM {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let base = Model::new(cfg, vb.clone())?;
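+        // With tied word embeddings, the output projection reuses the embedding matrix.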
+ let lm_head = if cfg.tie_word_embeddings {
+ Linear::from_weights(base.embed_tokens.embeddings().clone(), None)
+ } else {
+ linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+ };
+ Ok(Self { base, lm_head })
+ }
+
+    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
+ let (_, l) = input.dims2()?;
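+        // Only the last position's logits are needed for next-token sampling.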
+ self.base
+ .forward(input, offset)?
+ .narrow(1, l - 1, 1)?
+ .apply(&self.lm_head)
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ self.base.clear_kv_cache();
+ }
+}