diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index b21d6751..ee348de2 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -61,6 +61,8 @@ enum Which {
     OpenChat35,
     #[value(name = "7b-starling-a")]
     Starling7bAlpha,
+    #[value(name = "mixtral")]
+    Mixtral,
 }
 
 impl Which {
@@ -83,6 +85,7 @@ impl Which {
             | Self::Starling7bAlpha
             | Self::Zephyr7bAlpha
             | Self::Zephyr7bBeta
+            | Self::Mixtral
             | Self::Mistral7b
             | Self::Mistral7bInstruct => true,
         }
@@ -101,6 +104,7 @@ impl Which {
             | Self::L34bCode
             | Self::Leo7b
             | Self::Leo13b
+            | Self::Mixtral
             | Self::Mistral7b
             | Self::Mistral7bInstruct
             | Self::OpenChat35
@@ -122,6 +126,7 @@ impl Which {
             | Self::L34bCode
             | Self::Leo7b
             | Self::Leo13b
+            | Self::Mixtral
             | Self::Mistral7b
             | Self::Mistral7bInstruct
             | Self::Zephyr7bAlpha
@@ -143,6 +148,7 @@ impl Which {
             | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
             Which::Leo7b => "LeoLM/leo-hessianai-7b",
             Which::Leo13b => "LeoLM/leo-hessianai-13b",
+            Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
             Which::Mistral7b
             | Which::Mistral7bInstruct
             | Which::Zephyr7bAlpha
@@ -256,6 +262,10 @@ impl Args {
                         "TheBloke/leo-hessianai-13B-GGUF",
                         "leo-hessianai-13b.Q4_K_M.gguf",
                     ),
+                    Which::Mixtral => (
+                        "TheBloke/Mixtral-8x7B-v0.1-GGUF",
+                        "mixtral-8x7b-v0.1.Q4_K_M.gguf",
+                    ),
                     Which::Mistral7b => (
                         "TheBloke/Mistral-7B-v0.1-GGUF",
                         "mistral-7b-v0.1.Q4_K_S.gguf",
@@ -374,7 +384,8 @@ fn main() -> anyhow::Result<()> {
                 | Which::L34bCode
                 | Which::Leo7b
                 | Which::Leo13b => 1,
-                Which::Mistral7b
+                Which::Mixtral
+                | Which::Mistral7b
                 | Which::Mistral7bInstruct
                 | Which::Zephyr7bAlpha
                 | Which::Zephyr7bBeta
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 44d89f40..1fb2d9e2 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -47,6 +47,102 @@ impl QMatMul {
     }
 }
 
+#[derive(Debug, Clone)]
+struct Mlp {
+    feed_forward_w1: QMatMul,
+    feed_forward_w2: QMatMul,
+    feed_forward_w3: QMatMul,
+}
+
+impl Module for Mlp {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let w1 = self.feed_forward_w1.forward(xs)?;
+        let w3 = self.feed_forward_w3.forward(xs)?;
+        self.feed_forward_w2
+            .forward(&(candle_nn::ops::silu(&w1)? * w3)?)
+    }
+}
+
+#[derive(Debug, Clone)]
+enum MlpOrMoe {
+    Mlp(Mlp),
+    MoE {
+        n_expert_used: usize,
+        feed_forward_gate_inp: QMatMul,
+        experts: Vec<Mlp>,
+    },
+}
+
+impl Module for MlpOrMoe {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::MoE {
+                feed_forward_gate_inp,
+                experts,
+                n_expert_used,
+            } => {
+                let (b_size, seq_len, hidden_dim) = xs.dims3()?;
+                let xs = xs.reshape(((), hidden_dim))?;
+                let router_logits = feed_forward_gate_inp.forward(&xs)?;
+                let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
+
+                // In order to extract topk, we extract the data from the tensor and manipulate it
+                // directly. Maybe we will want to use some custom ops instead at some point.
+                let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+
+                // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+                // top_x contains the row indexes to evaluate for each expert.
+                let mut top_x = vec![vec![]; experts.len()];
+                let mut selected_rws = vec![vec![]; experts.len()];
+                for (row_idx, rw) in routing_weights.iter().enumerate() {
+                    let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
+                    dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
+                    let mut sum_routing_weights = 0f32;
+                    for &expert_idx in dst.iter().take(*n_expert_used) {
+                        let expert_idx = expert_idx as usize;
+                        let routing_weight = rw[expert_idx];
+                        sum_routing_weights += routing_weight;
+                        top_x[expert_idx].push(row_idx as u32);
+                    }
+                    for &expert_idx in dst.iter().take(*n_expert_used) {
+                        let expert_idx = expert_idx as usize;
+                        let routing_weight = rw[expert_idx];
+                        selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
+                    }
+                }
+
+                // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+                // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+                let mut ys = xs.zeros_like()?;
+                for (expert_idx, expert_layer) in experts.iter().enumerate() {
+                    let top_x = &top_x[expert_idx];
+                    if top_x.is_empty() {
+                        continue;
+                    }
+                    let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
+                    let selected_rws =
+                        Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
+                            .reshape(((), 1))?;
+                    // Index the correct hidden states and compute the expert hidden state for
+                    // the current expert. We need to make sure to multiply the output hidden
+                    // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                    let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
+                    // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
+                    let current_hidden_states = expert_layer.forward(&current_state)?;
+                    let current_hidden_states =
+                        current_hidden_states.broadcast_mul(&selected_rws)?;
+                    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
+                }
+
+                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
+                Ok(ys)
+            }
+            Self::Mlp(mlp) => mlp.forward(xs),
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 struct LayerWeights {
     attention_wq: QMatMul,
@@ -54,9 +150,7 @@ struct LayerWeights {
     attention_wv: QMatMul,
     attention_wo: QMatMul,
     attention_norm: RmsNorm,
-    feed_forward_w1: QMatMul,
-    feed_forward_w2: QMatMul,
-    feed_forward_w3: QMatMul,
+    mlp_or_moe: MlpOrMoe,
     ffn_norm: RmsNorm,
     n_head: usize,
     n_kv_head: usize,
@@ -212,9 +306,16 @@ impl ModelWeights {
             let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
             let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
             let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
-            let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
-            let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
-            let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
+            let mlp_or_moe = {
+                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
+                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
+                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
+                MlpOrMoe::Mlp(Mlp {
+                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+                })
+            };
             let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
             let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
             let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
@@ -226,9 +327,7 @@ impl ModelWeights {
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
                 attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
-                feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
-                feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
-                feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+                mlp_or_moe,
                 ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
                 n_head: ct.hparams.n_head as usize,
                 n_kv_head: ct.hparams.n_head as usize / gqa,
@@ -265,6 +364,12 @@ impl ModelWeights {
         };
 
         // Parameter extraction from metadata.
+        let n_expert = md_get("llama.expert_count")
+            .and_then(|v| v.to_u32())
+            .unwrap_or(0) as usize;
+        let n_expert_used = md_get("llama.expert_used_count")
+            .and_then(|v| v.to_u32())
+            .unwrap_or(0) as usize;
         let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
         let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
         let block_count = md_get("llama.block_count")?.to_u32()? as usize;
@@ -289,9 +394,38 @@ impl ModelWeights {
             let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
             let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
             let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
-            let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
-            let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
-            let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
+            let mlp_or_moe = if n_expert <= 1 {
+                let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
+                let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
+                let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
+                MlpOrMoe::Mlp(Mlp {
+                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+                })
+            } else {
+                let feed_forward_gate_inp =
+                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
+                let mut experts = Vec::with_capacity(n_expert);
+                for i in 0..n_expert {
+                    let feed_forward_w1 =
+                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
+                    let feed_forward_w2 =
+                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
+                    let feed_forward_w3 =
+                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
+                    experts.push(Mlp {
+                        feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+                        feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+                        feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+                    })
+                }
+                MlpOrMoe::MoE {
+                    n_expert_used,
+                    feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
+                    experts,
+                }
+            };
             let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
             let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
             let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
@@ -303,9 +437,7 @@ impl ModelWeights {
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
                 attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
-                feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
-                feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
-                feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+                mlp_or_moe,
                 ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
                 n_head: head_count,
                 n_kv_head: head_count_kv,
@@ -360,12 +492,9 @@ impl ModelWeights {
             let _enter = layer.span_mlp.enter();
             let residual = &x;
             let x = layer.ffn_norm.forward(&x)?;
-            let w1 = layer.feed_forward_w1.forward(&x)?;
-            let w3 = layer.feed_forward_w3.forward(&x)?;
-            let mlp = layer
-                .feed_forward_w2
-                .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?;
-            layer_in = (mlp + residual)?;
+            let x = layer.mlp_or_moe.forward(&x)?;
+            let x = (x + residual)?;
+            layer_in = x
         }
         let x = self.norm.forward(&layer_in)?;
         let x = x.i((.., seq_len - 1, ..))?;
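
For readers less familiar with mixture-of-experts routing, here is a minimal standalone sketch (not part of the diff above) of the per-token top-k selection and renormalization that MlpOrMoe::forward performs on the softmaxed router weights; the helper name top_k_routing and the example numbers are made up for illustration:

    // Pick the `n_expert_used` largest routing weights for one token and
    // renormalize them so the selected experts' weights sum to 1, mirroring
    // the `sum_routing_weights` division in MlpOrMoe::forward above.
    fn top_k_routing(rw: &[f32], n_expert_used: usize) -> Vec<(usize, f32)> {
        let mut dst = (0..rw.len()).collect::<Vec<usize>>();
        dst.sort_by(|&i, &j| rw[j].total_cmp(&rw[i]));
        let top = &dst[..n_expert_used.min(dst.len())];
        let sum: f32 = top.iter().map(|&i| rw[i]).sum();
        top.iter().map(|&i| (i, rw[i] / sum)).collect()
    }

    fn main() {
        // Softmax output over 8 experts for a single token; top-2 routing as in Mixtral.
        let rw = [0.02, 0.40, 0.05, 0.03, 0.30, 0.10, 0.06, 0.04];
        println!("{:?}", top_k_routing(&rw, 2)); // experts 1 and 4, weights ~0.571 and ~0.429
    }

Each expert then runs only on the token rows routed to it, and its output is scaled by the corresponding renormalized weight before being index_add-ed back into the shared output buffer.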