diff --git a/candle-examples/examples/qwen/README.md b/candle-examples/examples/qwen/README.md index cb785f21..d81cd666 100644 --- a/candle-examples/examples/qwen/README.md +++ b/candle-examples/examples/qwen/README.md @@ -25,3 +25,28 @@ def print_prime(n: int): # n is the number of primes to be printed print(i) ``` +The qwen3 MoE variant is also an option. + +```bash +$ cargo run --example qwen --features metal --release -- --prompt "Write a poem about butterflies. ." --model "3-moe-a3b" +> In morning's hush, where daisies sleep, +> A fleeting dance through sunlit deep— +> They flutter soft on gossamer thread, +> The messengers of spring’s own head. +> +> With painted sails and delicate grace, +> They drift from bloom to blossom's face. +> Each wing a tale in hues unseen, +> Of ancient dreams and secrets between. +> +> No sound they make, yet still they speak— +> Of time that flies, of life so brief. +> A fleeting kiss on summer’s breath, +> A whisper lost before death. +> +> Yet in their flight, the soul takes wing, +> And for a moment, all is spring. +> For though they fade, they never die— +> Their beauty lives where hearts can fly. 
+> 161 tokens generated (3.00 token/s) +``` diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index d0e179e0..3b90b9fb 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -10,6 +10,7 @@ use clap::Parser; use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase}; use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe}; use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3}; +use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, ModelForCausalLM as ModelMoe3}; use candle::{DType, Device, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; @@ -22,6 +23,7 @@ enum Model { Base(ModelBase), Moe(ModelMoe), Base3(Model3), + Moe3(ModelMoe3), } impl Model { @@ -30,6 +32,7 @@ impl Model { Self::Moe(ref mut m) => m.forward(xs, s), Self::Base(ref mut m) => m.forward(xs, s), Self::Base3(ref mut m) => m.forward(xs, s), + Self::Moe3(ref mut m) => m.forward(xs, s), } } } @@ -167,6 +170,8 @@ enum WhichModel { W3_4b, #[value(name = "3-8b")] W3_8b, + #[value(name = "3-moe-a3b")] + W3MoeA3b, } #[derive(Parser, Debug)] @@ -273,6 +278,7 @@ fn main() -> Result<()> { WhichModel::W3_1_7b => ("3", "1.7B"), WhichModel::W3_4b => ("3", "4B"), WhichModel::W3_8b => ("3", "8B"), + WhichModel::W3MoeA3b => ("3", "30B-A3B"), }; format!("Qwen/Qwen{version}-{size}") } @@ -308,7 +314,8 @@ fn main() -> Result<()> { | WhichModel::MoeA27b | WhichModel::W3_1_7b | WhichModel::W3_4b - | WhichModel::W3_8b => { + | WhichModel::W3_8b + | WhichModel::W3MoeA3b => { candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? } }, @@ -334,6 +341,10 @@ fn main() -> Result<()> { let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?; Model::Base3(Model3::new(&config, vb)?) 
} + WhichModel::W3MoeA3b => { + let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?; + Model::Moe3(ModelMoe3::new(&config, vb)?) + } + _ => { + let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?; + Model::Base(ModelBase::new(&config, vb)?) diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index d8f71b44..8d80b183 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -100,6 +100,7 @@ pub mod quantized_t5; pub mod qwen2; pub mod qwen2_moe; pub mod qwen3; +pub mod qwen3_moe; pub mod recurrent_gemma; pub mod repvgg; pub mod resnet; diff --git a/candle-transformers/src/models/qwen3_moe.rs b/candle-transformers/src/models/qwen3_moe.rs new file mode 100644 index 00000000..e88a0538 --- /dev/null +++ b/candle-transformers/src/models/qwen3_moe.rs @@ -0,0 +1,355 @@ +use crate::models::{ + qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding}, + with_tracing::{linear_no_bias, Linear, RmsNorm}, +}; +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{Activation, VarBuilder}; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub struct Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub head_dim: usize, + pub attention_bias: bool, + pub num_key_value_heads: usize, + pub max_position_embeddings: usize, + pub sliding_window: Option<usize>, + pub max_window_layers: usize, + pub tie_word_embeddings: bool, + pub rope_theta: f64, + pub rms_norm_eps: f64, + pub use_sliding_window: bool, + pub hidden_act: Activation, + // MoE specific configuration + pub decoder_sparse_step: usize, + pub moe_intermediate_size: usize, + pub num_experts_per_tok: usize, + pub num_experts: usize, + pub norm_topk_prob: bool, +} + +impl From<&Config> for Qwen3Config { + fn from(val: &Config) -> Self
{ + Qwen3Config { + vocab_size: val.vocab_size, + hidden_size: val.hidden_size, + intermediate_size: val.intermediate_size, + num_hidden_layers: val.num_hidden_layers, + num_attention_heads: val.num_attention_heads, + head_dim: val.head_dim, + attention_bias: val.attention_bias, + num_key_value_heads: val.num_key_value_heads, + max_position_embeddings: val.max_position_embeddings, + sliding_window: val.sliding_window, + max_window_layers: val.max_window_layers, + tie_word_embeddings: val.tie_word_embeddings, + rope_theta: val.rope_theta, + rms_norm_eps: val.rms_norm_eps, + use_sliding_window: val.use_sliding_window, + hidden_act: val.hidden_act, + } + } +} + +#[derive(Debug, Clone)] +struct Qwen3MLPExpert { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: Activation, +} + +impl Qwen3MLPExpert { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + Ok(Self { + gate_proj: linear_no_bias( + cfg.hidden_size, + cfg.moe_intermediate_size, + vb.pp("gate_proj"), + )?, + up_proj: linear_no_bias(cfg.hidden_size, cfg.moe_intermediate_size, vb.pp("up_proj"))?, + down_proj: linear_no_bias( + cfg.moe_intermediate_size, + cfg.hidden_size, + vb.pp("down_proj"), + )?, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for Qwen3MLPExpert { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = x.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +// Qwen3 Sparse MoE Block implementation +#[derive(Debug, Clone)] +struct Qwen3SparseMoeBlock { + gate: Linear, + experts: Vec<Qwen3MLPExpert>, + norm_topk_prob: bool, + num_experts_per_tok: usize, +} + +impl Qwen3SparseMoeBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let gate = linear_no_bias(cfg.hidden_size, cfg.num_experts, vb.pp("gate"))?; + let mut experts = Vec::with_capacity(cfg.num_experts); + let vb_e = vb.pp("experts"); + for idx in 0..cfg.num_experts { + let expert = Qwen3MLPExpert::new(cfg, vb_e.pp(idx))?; + experts.push(expert)
+ } + Ok(Self { + gate, + experts, + norm_topk_prob: cfg.norm_topk_prob, + num_experts_per_tok: cfg.num_experts_per_tok, + }) + } +} + +impl Module for Qwen3SparseMoeBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (b_size, seq_len, hidden_dim) = xs.dims3()?; + let xs = xs.reshape(((), hidden_dim))?; + let router_logits = xs.apply(&self.gate)?; + let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + + // Extract topk experts per token + let experts_per_tok = routing_weights + .arg_sort_last_dim(false)? + .narrow(D::Minus1, 0, self.num_experts_per_tok)? + .contiguous()?; + let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?; + + // Extract needed data + let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?; + let experts_per_tok = experts_per_tok.to_vec2::<u32>()?; + let mut top_x = vec![vec![]; self.experts.len()]; + let mut selected_experts = vec![vec![]; self.experts.len()]; + for (row_idx, (rw, expert_idxs)) in routing_weights + .iter() + .zip(experts_per_tok.iter()) + .enumerate() + { + let sum_rw = rw.iter().sum::<f32>(); + for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) { + top_x[expert_idx as usize].push(row_idx as u32); + let rw = if self.norm_topk_prob { rw / sum_rw } else { rw }; + selected_experts[expert_idx as usize].push(rw) + } + } + + // Process through experts + let mut ys = xs.zeros_like()?; + for (expert_idx, expert_layer) in self.experts.iter().enumerate() { + let top_x = &top_x[expert_idx]; + if top_x.is_empty() { + continue; + } + let top_x = Tensor::new(top_x.as_slice(), xs.device())?; + let selected_experts = + Tensor::new(selected_experts[expert_idx].as_slice(), xs.device())? + .reshape(((), 1))?
+ .to_dtype(xs.dtype())?; + + let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?; + let current_hidden_states = expert_layer.forward(&current_state)?; + let current_hidden_states = current_hidden_states.broadcast_mul(&selected_experts)?; + ys = ys.index_add(&top_x, &current_hidden_states, 0)?; + } + + ys.reshape((b_size, seq_len, hidden_dim)) + } +} + +// MLP or MoE decision enum +#[derive(Debug, Clone)] +enum Qwen3FeedForward { + Mlp(Qwen3MLP), + MoE(Qwen3SparseMoeBlock), +} + +impl Module for Qwen3FeedForward { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match self { + Self::Mlp(m) => m.forward(xs), + Self::MoE(m) => m.forward(xs), + } + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Qwen3Attention, + feed_forward: Qwen3FeedForward, + ln1: RmsNorm, + ln2: RmsNorm, +} + +impl DecoderLayer { + fn new( + layer_idx: usize, + cfg: &Config, + rotary: Arc<Qwen3RotaryEmbedding>, + vb: VarBuilder, + ) -> Result<Self> { + let self_attn = Qwen3Attention::new(&cfg.into(), rotary, vb.pp("self_attn"))?; + + // Decide whether to use MoE or regular MLP based on layer_idx and decoder_sparse_step + let feed_forward = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0 + { + Qwen3FeedForward::MoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?) + } else { + Qwen3FeedForward::Mlp(Qwen3MLP::new(&cfg.into(), vb.pp("mlp"))?)
+ }; + + let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let ln2 = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + + Ok(Self { + self_attn, + feed_forward, + ln1, + ln2, + }) + } + + fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> { + let h = self.ln1.forward(x)?; + let h = self.self_attn.forward(&h, mask, offset)?; + let x = (x + h)?; + let h2 = self.ln2.forward(&x)?; + let h2 = h2.apply(&self.feed_forward)?; + x + h2 + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache(); + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec<DecoderLayer>, + norm: RmsNorm, + device: Device, + dtype: DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; + let rotary = Arc::new(Qwen3RotaryEmbedding::new( + vb.dtype(), + &cfg.into(), + vb.device(), + )?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb.pp("model.layers"); + for i in 0..cfg.num_hidden_layers { + layers.push(DecoderLayer::new(i, cfg, rotary.clone(), vb_l.pp(i))?); + } + Ok(Self { + embed_tokens, + layers, + norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + fn clear_kv_cache(&mut self) { + for l in &mut self.layers { + l.clear_kv_cache(); + } + } + + fn causal_mask( + &self, + b: usize, + tgt: usize, + offset: usize, + sw: Option<usize>, + ) -> Result<Tensor> { + let minf = f32::NEG_INFINITY; + let mask: Vec<_> = (0..tgt) + .flat_map(|i| { + (0..(tgt + offset)).map(move |j| { + let past_ok = j <= i + offset; + let sw_ok = match sw { + Some(w) => (i + offset) as i64 - j as i64 <= w as i64, + None => true, + }; + if past_ok && sw_ok { + 0.
+ } else { + minf + } + }) + }) + .collect(); + Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> { + let (b, l) = input.dims2()?; + let mut h = self.embed_tokens.forward(input)?; + + let causal = if l == 1 { + None + } else { + Some(self.causal_mask(b, l, offset, None)?) + }; + + for layer in &mut self.layers { + h = layer.forward(&h, causal.as_ref(), offset)?; + } + self.norm.forward(&h) + } +} + +#[derive(Debug, Clone)] +pub struct ModelForCausalLM { + base: Model, + lm_head: Linear, +} + +impl ModelForCausalLM { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let base = Model::new(cfg, vb.clone())?; + let lm_head = if cfg.tie_word_embeddings { + Linear::from_weights(base.embed_tokens.embeddings().clone(), None) + } else { + linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? + }; + Ok(Self { base, lm_head }) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> { + let (_, l) = input.dims2()?; + self.base + .forward(input, offset)? + .narrow(1, l - 1, 1)? + .apply(&self.lm_head) + } + + pub fn clear_kv_cache(&mut self) { + self.base.clear_kv_cache(); + } +}