//! Qwen3 implementation with quantization support.
//!
//! Implements the Qwen3 architecture with quantized weights for reduced
//! memory usage and faster inference on compatible hardware.
//!
//! References:
//! - [Qwen3 Models](https://huggingface.co/Qwen/Qwen3-0.6B) (architecture based on official implementations)
//!
use super::with_tracing::QMatMul;
use crate::{quantized_nn::RmsNorm, utils::repeat_kv};
use candle::quantized::{gguf_file, QTensor};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{kv_cache::KvCache, Activation, Embedding, Module};
use std::io::{Read, Seek};
use std::sync::Arc;

/// Helper bundling the GGUF content, its reader and the target device so
/// quantized tensors can be loaded by name.
struct Gguf<R: Read + Seek> {
    ct: gguf_file::Content,
    reader: R,
    device: Device,
}

impl<R: Read + Seek> Gguf<R> {
    fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self {
        Self { ct, reader, device }
    }

    fn qmatmul(&mut self, name: &str) -> Result<QMatMul> {
        let ws = self.ct.tensor(&mut self.reader, name, &self.device)?;
        QMatMul::from_weights(ws.into())
    }

    fn rms_norm(&mut self, name: &str, eps: f64) -> Result<RmsNorm> {
        let ws = self.ct.tensor(&mut self.reader, name, &self.device)?;
        RmsNorm::from_qtensor(ws, eps)
    }

    fn metadata(&self) -> &std::collections::HashMap<String, gguf_file::Value> {
        &self.ct.metadata
    }

    fn tensor(&mut self, name: &str) -> Result<QTensor> {
        self.ct.tensor(&mut self.reader, name, &self.device)
    }
}

#[derive(Debug, Clone)]
struct MlpWeights {
    gate_proj: QMatMul,
    up_proj: QMatMul,
    down_proj: QMatMul,
    act_fn: Activation,
    span: tracing::Span,
}

impl MlpWeights {
    fn new<R: Read + Seek>(gg: &mut Gguf<R>, prefix: &str) -> Result<Self> {
        let gate_proj = gg.qmatmul(&format!("{prefix}.ffn_gate.weight"))?;
        let up_proj = gg.qmatmul(&format!("{prefix}.ffn_up.weight"))?;
        let down_proj = gg.qmatmul(&format!("{prefix}.ffn_down.weight"))?;
        let act_fn = Activation::Silu;
        let span = tracing::span!(tracing::Level::TRACE, "mlp");
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn,
            span,
        })
    }
}

impl Module for MlpWeights {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let gate = self.gate_proj.forward(x)?.apply(&self.act_fn)?;
        let up = self.up_proj.forward(x)?;
        let gated = (gate * up)?;
        self.down_proj.forward(&gated)
    }
}
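// Note (derived from the code above, not additional behaviour): the MLP is the
// standard SwiGLU block, i.e. down_proj(silu(gate_proj(x)) * up_proj(x)), and it
// expects the GGUF tensor names "{prefix}.ffn_gate.weight",
// "{prefix}.ffn_up.weight" and "{prefix}.ffn_down.weight" where the prefix is
// "blk.<layer_idx>".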
#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

impl RotaryEmbedding {
    fn new(
        dtype: DType,
        head_dim: usize,
        max_position_embeddings: usize,
        rope_theta: f64,
        dev: &Device,
    ) -> Result<Self> {
        let dim = head_dim;
        let max_seq_len = max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    /// Apply RoPE (q, k shape: B x H x L x D).
    fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
        let (_, _, seq_len, _) = q.dims4()?;
        let cos = self.cos.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?;
        let sin = self.sin.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?;
        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
        Ok((q_embed, k_embed))
    }
}

#[derive(Debug, Clone)]
struct AttentionWeights {
    q_proj: QMatMul,
    k_proj: QMatMul,
    v_proj: QMatMul,
    o_proj: QMatMul,
    q_norm: RmsNorm,
    k_norm: RmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: KvCache,
    span_attn: tracing::Span,
}

impl AttentionWeights {
    fn new<R: Read + Seek>(
        gg: &mut Gguf<R>,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        rms_norm_eps: f64,
        rotary_emb: Arc<RotaryEmbedding>,
        prefix: &str,
    ) -> Result<Self> {
        let num_kv_groups = num_heads / num_kv_heads;
        let q_proj = gg.qmatmul(&format!("{prefix}.attn_q.weight"))?;
        let k_proj = gg.qmatmul(&format!("{prefix}.attn_k.weight"))?;
        let v_proj = gg.qmatmul(&format!("{prefix}.attn_v.weight"))?;
        let o_proj = gg.qmatmul(&format!("{prefix}.attn_output.weight"))?;
        let q_norm = gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?;
        let k_norm = gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?;
        // Initialize the KV cache with a 512 token capacity to reduce the initial
        // memory allocation; the cache grows in chunks of 512 tokens when needed.
        let kv_cache = KvCache::new(2, 512);
        let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            q_norm,
            k_norm,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            rotary_emb,
            kv_cache,
            span_attn,
        })
    }

    fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        let _enter = self.span_attn.enter();
        let (b, l, _) = x.dims3()?;

        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        let q = q
            .reshape((b, l, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b, l, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b, l, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        // Per-head RMS norm on q and k before RoPE.
        let q_flat = q.flatten(0, 2)?;
        let k_flat = k.flatten(0, 2)?;
        let q_flat = self.q_norm.forward(&q_flat)?;
        let k_flat = self.k_norm.forward(&k_flat)?;
        let q = q_flat.reshape((b, self.num_heads, l, self.head_dim))?;
        let k = k_flat.reshape((b, self.num_kv_heads, l, self.head_dim))?;

        let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;

        // Reset the KV cache if we're at the first position.
        if offset == 0 {
            self.kv_cache.reset();
        }
        let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;

        let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
        let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;

        let scale = 1.0 / (self.head_dim as f64).sqrt();
        let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

        if let Some(m) = attn_mask {
            let mask = if m.dtype() != scores.dtype() {
                m.to_dtype(scores.dtype())?
            } else {
                m.clone()
            };
            scores = scores.broadcast_add(&mask)?;
        }

        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        let ctx = probs.matmul(&v)?; // (B, H, L, D)

        let reshaped_ctx = ctx
            .transpose(1, 2)?
            .reshape((b, l, self.num_heads * self.head_dim))?;
        self.o_proj.forward(&reshaped_ctx)
    }
}
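// Shape walkthrough for the attention forward pass above (derived from the
// code, not additional behaviour): with B = batch, L = new tokens,
// T = offset + L cached positions, H = num_heads, Hkv = num_kv_heads and
// D = head_dim,
//   q: (B, H, L, D) and k, v: (B, Hkv, T, D) after the KV-cache append;
//   repeat_kv tiles k and v to (B, H, T, D) so each KV head serves
//   H / Hkv query heads (grouped-query attention);
//   scores: (B, H, L, T), ctx: (B, H, L, D), output reshaped to (B, L, H * D).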
#[derive(Debug, Clone)]
struct LayerWeights {
    self_attn: AttentionWeights,
    mlp: MlpWeights,
    ln1: RmsNorm,
    ln2: RmsNorm,
}

impl LayerWeights {
    fn new<R: Read + Seek>(
        gg: &mut Gguf<R>,
        num_attention_heads: usize,
        num_key_value_heads: usize,
        head_dim: usize,
        rms_norm_eps: f64,
        rotary: Arc<RotaryEmbedding>,
        layer_idx: usize,
    ) -> Result<Self> {
        let prefix = format!("blk.{layer_idx}");
        let ln1 = gg.rms_norm(&format!("{prefix}.attn_norm.weight"), rms_norm_eps)?;
        let ln2 = gg.rms_norm(&format!("{prefix}.ffn_norm.weight"), rms_norm_eps)?;
        let self_attn = AttentionWeights::new(
            gg,
            num_attention_heads,
            num_key_value_heads,
            head_dim,
            rms_norm_eps,
            rotary,
            &prefix,
        )?;
        let mlp = MlpWeights::new(gg, &prefix)?;
        Ok(Self {
            self_attn,
            mlp,
            ln1,
            ln2,
        })
    }

    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        let h = self.ln1.forward(x)?;
        let h = self.self_attn.forward(&h, mask, offset)?;
        let x = (x + h)?;
        let h2 = self.ln2.forward(&x)?;
        let h2 = h2.apply(&self.mlp)?;
        x + h2
    }
}
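// Each decoder layer above is a pre-norm residual block: the attention branch
// computes x + attn(attn_norm(x)) and the MLP branch computes
// x + mlp(ffn_norm(x)), matching LayerWeights::forward.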
#[derive(Debug, Clone)]
pub struct ModelWeights {
    embed_tokens: Embedding,
    layers: Vec<LayerWeights>,
    norm: RmsNorm,
    lm_head: QMatMul,
    device: Device,
    dtype: DType,
    span: tracing::Span,
    span_output: tracing::Span,
}

impl ModelWeights {
    pub fn from_gguf<R: Read + Seek>(
        ct: gguf_file::Content,
        reader: &mut R,
        device: &Device,
    ) -> Result<Self> {
        let mut gg = Gguf::new(ct, reader, device.clone());
        let md_get = |s: &str| match gg.metadata().get(s) {
            None => candle::bail!("cannot find {s} in metadata"),
            Some(v) => Ok(v),
        };

        let num_attention_heads = md_get("qwen3.attention.head_count")?.to_u32()? as usize;
        let num_kv_heads = md_get("qwen3.attention.head_count_kv")?.to_u32()? as usize;
        let head_dim = md_get("qwen3.attention.key_length")?.to_u32()? as usize;
        let num_layers = md_get("qwen3.block_count")?.to_u32()? as usize;
        let hidden_size = md_get("qwen3.embedding_length")?.to_u32()? as usize;
        let max_position_embeddings = md_get("qwen3.context_length")?.to_u32()? as usize;
        let rms_norm_eps = md_get("qwen3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
        let rope_freq_base = md_get("qwen3.rope.freq_base")?.to_f32()? as f64;

        let dtype = match gg.metadata().get("general.dtype") {
            Some(v) => match v.to_u32() {
                Ok(0) => DType::F32,
                Ok(1) => DType::F16,
                _ => DType::F16,
            },
            None => DType::F16,
        };

        let embed_tensor = gg.tensor("token_embd.weight")?;
        let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size);

        let rotary = Arc::new(RotaryEmbedding::new(
            dtype,
            head_dim,
            max_position_embeddings,
            rope_freq_base,
            device,
        )?);

        let mut layers = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            layers.push(LayerWeights::new(
                &mut gg,
                num_attention_heads,
                num_kv_heads,
                head_dim,
                rms_norm_eps,
                rotary.clone(),
                i,
            )?);
        }

        let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;

        // Load the output projection tensor, falling back to tied embeddings like gemma3.
        let lm_head_tensor = match gg.tensor("output.weight") {
            Ok(tensor) => tensor,
            Err(_) => gg.tensor("token_embd.weight")?,
        };
        let lm_head = QMatMul::from_weights(lm_head_tensor.into())?;

        let span = tracing::span!(tracing::Level::TRACE, "model");
        let span_output = tracing::span!(tracing::Level::TRACE, "output");

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: device.clone(),
            dtype,
            span,
            span_output,
        })
    }

    fn causal_mask(
        &self,
        b: usize,
        tgt: usize,
        offset: usize,
        sw: Option<usize>,
    ) -> Result<Tensor> {
        let minf = f32::NEG_INFINITY;
        let mask: Vec<_> = (0..tgt)
            .flat_map(|i| {
                (0..(tgt + offset)).map(move |j| {
                    let past_ok = j <= i + offset;
                    let sw_ok = match sw {
                        Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
                        None => true,
                    };
                    if past_ok && sw_ok {
                        0.
                    } else {
                        minf
                    }
                })
            })
            .collect();
        Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
    }

    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (b, l) = input.dims2()?;
        let mut h = self.embed_tokens.forward(input)?;

        // A causal mask is only needed when more than one new token is processed.
        let causal_mask = if l == 1 {
            None
        } else {
            Some(self.causal_mask(b, l, offset, None)?)
        };

        for layer in &mut self.layers {
            h = layer.forward(&h, causal_mask.as_ref(), offset)?;
        }

        let h = self.norm.forward(&h)?;

        let _enter = self.span_output.enter();
        // Only the last position is projected to logits.
        let last_hidden = h.narrow(1, l - 1, 1)?;
        self.lm_head.forward(&last_hidden)?.squeeze(1)
    }
}
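// ---------------------------------------------------------------------------
// Usage sketch: a minimal, ignored-by-default test showing how a caller could
// load a quantized Qwen3 GGUF checkpoint and run a single forward pass. The
// file path and the token ids below are placeholders, not values shipped with
// this module.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::ModelWeights;
    use candle::quantized::gguf_file;
    use candle::{Device, Result, Tensor};

    #[test]
    #[ignore = "requires a local Qwen3 GGUF checkpoint"]
    fn load_and_forward() -> Result<()> {
        let device = Device::Cpu;
        // Hypothetical path to a quantized Qwen3 checkpoint.
        let mut file = std::fs::File::open("qwen3-0.6b-q4_k_m.gguf")?;
        let content = gguf_file::Content::read(&mut file)?;
        let mut model = ModelWeights::from_gguf(content, &mut file, &device)?;

        // Dummy prompt ids: in real usage they would come from the Qwen3 tokenizer.
        let input = Tensor::new(&[[1u32, 2, 3]], &device)?;
        // The first call uses offset 0; subsequent decoding steps would pass the
        // number of tokens already processed so the KV cache is reused.
        let logits = model.forward(&input, 0)?;
        assert_eq!(logits.dims()[0], 1); // batch dimension
        Ok(())
    }
}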