diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 5e163cb6..bac9e7e7 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -29,13 +29,13 @@ impl QMatMul {
 }
 
 #[derive(Debug, Clone)]
-struct Mlp {
+struct MlpSilu {
     feed_forward_w1: QMatMul,
     feed_forward_w2: QMatMul,
     feed_forward_w3: QMatMul,
 }
 
-impl Module for Mlp {
+impl Module for MlpSilu {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let w1 = self.feed_forward_w1.forward(xs)?;
         let w3 = self.feed_forward_w3.forward(xs)?;
@@ -45,16 +45,31 @@ impl Module for Mlp {
 }
 
 #[derive(Debug, Clone)]
-enum MlpOrMoe {
-    Mlp(Mlp),
+struct MlpSimple {
+    fc1: QMatMul,
+    fc2: QMatMul,
+    act: candle_nn::Activation,
+}
+
+impl Module for MlpSimple {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.fc1.forward(xs)?.apply(&self.act)?;
+        self.fc2.forward(&xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+enum Mlp {
+    Silu(MlpSilu),
+    Simple(MlpSimple),
     MoE {
         n_expert_used: usize,
         feed_forward_gate_inp: QMatMul,
-        experts: Vec<Mlp>,
+        experts: Vec<MlpSilu>,
     },
 }
 
-impl Module for MlpOrMoe {
+impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::MoE {
@@ -119,7 +134,8 @@ impl Module for MlpOrMoe {
                 let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                 Ok(ys)
             }
-            Self::Mlp(mlp) => mlp.forward(xs),
+            Self::Silu(mlp) => mlp.forward(xs),
+            Self::Simple(mlp) => mlp.forward(xs),
         }
     }
 }
@@ -131,7 +147,7 @@ struct LayerWeights {
     attention_wv: QMatMul,
     attention_wo: QMatMul,
     attention_norm: RmsNorm,
-    mlp_or_moe: MlpOrMoe,
+    mlp: Mlp,
     ffn_norm: RmsNorm,
     n_head: usize,
     n_kv_head: usize,
@@ -324,11 +340,11 @@ impl ModelWeights {
             let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
             let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
             let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
-            let mlp_or_moe = {
+            let mlp = {
                 let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                 let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                 let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
-                MlpOrMoe::Mlp(Mlp {
+                Mlp::Silu(MlpSilu {
                     feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                     feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                     feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
@@ -345,7 +361,7 @@ impl ModelWeights {
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
                 attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
-                mlp_or_moe,
+                mlp,
                 ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
                 n_head: ct.hparams.n_head as usize,
                 n_kv_head: ct.hparams.n_head as usize / gqa,
@@ -397,14 +413,14 @@ impl ModelWeights {
             let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
             let attention_wo =
                 ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
-            let mlp_or_moe = if cfg.n_expert <= 1 {
+            let mlp = if cfg.n_expert <= 1 {
                 let feed_forward_w1 =
                     ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
                 let feed_forward_w2 =
                     ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
                 let feed_forward_w3 =
                     ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
-                MlpOrMoe::Mlp(Mlp {
+                Mlp::Silu(MlpSilu {
                     feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                     feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                     feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
@@ -420,13 +436,13 @@ impl ModelWeights {
                         ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
                     let feed_forward_w3 =
                         ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
-                    experts.push(Mlp {
+                    experts.push(MlpSilu {
                         feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                         feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                         feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                     })
                 }
-                MlpOrMoe::MoE {
+                Mlp::MoE {
                     n_expert_used: cfg.n_expert_used,
                     feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
                     experts,
@@ -444,7 +460,7 @@ impl ModelWeights {
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
                 attention_norm: RmsNorm::from_qtensor(attention_norm, cfg.rms_norm_eps)?,
-                mlp_or_moe,
+                mlp,
                 ffn_norm: RmsNorm::from_qtensor(ffn_norm, cfg.rms_norm_eps)?,
                 n_head: cfg.head_count,
                 n_kv_head: cfg.head_count_kv,
@@ -504,7 +520,7 @@ impl ModelWeights {
             let _enter = layer.span_mlp.enter();
             let residual = &x;
             let x = layer.ffn_norm.forward(&x)?;
-            let x = layer.mlp_or_moe.forward(&x)?;
+            let x = layer.mlp.forward(&x)?;
            let x = (x + residual)?;
             layer_in = x
         }
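Note that the hunks above only ever construct `Mlp::Silu` and `Mlp::MoE`; the new `Mlp::Simple` variant is defined here for architectures with a plain up/down MLP. Below is a minimal sketch of how such a variant could be built inside the `from_gguf` layer loop (it relies on the in-scope `ct`, `reader`, `prefix`, and `device` from that loop). The `ffn_up`/`ffn_down` tensor names and the GELU activation are illustrative assumptions, not taken from this diff.

```rust
// Hypothetical sketch (not part of this diff): constructing the new
// Mlp::Simple variant from GGUF tensors. The tensor names and the Gelu
// activation are assumptions chosen for illustration only.
let mlp = {
    let up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
    let down = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
    Mlp::Simple(MlpSimple {
        fc1: QMatMul::from_qtensor(up)?,
        fc2: QMatMul::from_qtensor(down)?,
        act: candle_nn::Activation::Gelu,
    })
};
```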