Rework the MLP bit.

Author: laurent
Date:   2024-04-17 09:28:50 +02:00
Parent: af11b2d461
Commit: d79041d94d

The change renames the silu-gated feed-forward to MlpSilu, adds a plain two-projection MlpSimple with a configurable activation, and folds both, together with the mixture-of-experts case, into a single Mlp enum that replaces the old MlpOrMoe.

@@ -29,13 +29,13 @@ impl QMatMul {
 }
 
 #[derive(Debug, Clone)]
-struct Mlp {
+struct MlpSilu {
     feed_forward_w1: QMatMul,
     feed_forward_w2: QMatMul,
     feed_forward_w3: QMatMul,
 }
 
-impl Module for Mlp {
+impl Module for MlpSilu {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let w1 = self.feed_forward_w1.forward(xs)?;
         let w3 = self.feed_forward_w3.forward(xs)?;
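
What the renamed MlpSilu computes is the silu-gated feed-forward used by llama-style models: out = w2(silu(w1·x) ⊙ (w3·x)), with silu(v) = v·σ(v). Below is a minimal sketch of just the gating math on plain f32 slices; the helper names are illustrative, and the real code runs the same computation through quantized QMatMul projections.

```rust
// silu(v) = v * sigmoid(v), the activation applied to the w1 ("gate") branch.
fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

/// Element-wise silu(w1x) * w3x: the gated hidden activation that is then
/// projected back down through w2 in the real model.
fn gated_hidden(w1x: &[f32], w3x: &[f32]) -> Vec<f32> {
    w1x.iter().zip(w3x).map(|(&a, &b)| silu(a) * b).collect()
}

fn main() {
    let w1x = [0.5_f32, -1.0, 2.0]; // x projected through w1 (gate branch)
    let w3x = [1.0_f32, 0.5, -0.3]; // x projected through w3 (up branch)
    println!("{:?}", gated_hidden(&w1x, &w3x));
}
```
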
@@ -45,16 +45,31 @@ impl Module for Mlp {
 }
 
 #[derive(Debug, Clone)]
-enum MlpOrMoe {
-    Mlp(Mlp),
+struct MlpSimple {
+    fc1: QMatMul,
+    fc2: QMatMul,
+    act: candle_nn::Activation,
+}
+
+impl Module for MlpSimple {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.fc1.forward(xs)?.apply(&self.act)?;
+        self.fc2.forward(&xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+enum Mlp {
+    Silu(MlpSilu),
+    Simple(MlpSimple),
     MoE {
         n_expert_used: usize,
         feed_forward_gate_inp: QMatMul,
-        experts: Vec<Mlp>,
+        experts: Vec<MlpSilu>,
     },
 }
 
-impl Module for MlpOrMoe {
+impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::MoE {
@@ -119,7 +134,8 @@ impl Module for MlpOrMoe {
                 let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                 Ok(ys)
             }
-            Self::Mlp(mlp) => mlp.forward(xs),
+            Self::Silu(mlp) => mlp.forward(xs),
+            Self::Simple(mlp) => mlp.forward(xs),
         }
     }
 }
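
The Silu and Simple arms just delegate, while the MoE arm routes each token through the top n_expert_used experts as scored by the feed_forward_gate_inp router, combining expert outputs with normalized weights. Here is a hedged sketch of that per-token top-k routing on plain f32 logits; `route` is an illustrative name, and the real code batches this over all tokens with tensor ops.

```rust
/// Pick the `k` largest router logits for one token and softmax-normalize
/// over just the selected experts, returning (expert_index, weight) pairs.
/// (Softmax over the selected logits equals a full softmax renormalized
/// over the selection, so this matches the usual MoE weighting.)
fn route(router_logits: &[f32], k: usize) -> Vec<(usize, f32)> {
    let mut idx: Vec<usize> = (0..router_logits.len()).collect();
    // Sort expert indices by descending logit and keep the best k.
    idx.sort_by(|&a, &b| router_logits[b].total_cmp(&router_logits[a]));
    idx.truncate(k);
    // Numerically stable softmax over the selected logits only.
    let max = idx.iter().map(|&i| router_logits[i]).fold(f32::MIN, f32::max);
    let exps: Vec<f32> = idx.iter().map(|&i| (router_logits[i] - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    idx.into_iter().zip(exps).map(|(i, e)| (i, e / sum)).collect()
}

fn main() {
    // Router scores for 4 experts; send the token to the best 2.
    let logits = [0.1_f32, 2.0, -0.5, 1.5];
    for (expert, weight) in route(&logits, 2) {
        println!("expert {expert} with weight {weight:.3}");
    }
}
```
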
@@ -131,7 +147,7 @@ struct LayerWeights {
     attention_wv: QMatMul,
     attention_wo: QMatMul,
     attention_norm: RmsNorm,
-    mlp_or_moe: MlpOrMoe,
+    mlp: Mlp,
     ffn_norm: RmsNorm,
     n_head: usize,
     n_kv_head: usize,
@@ -324,11 +340,11 @@ impl ModelWeights {
             let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
             let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
             let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
-            let mlp_or_moe = {
+            let mlp = {
                 let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                 let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                 let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
-                MlpOrMoe::Mlp(Mlp {
+                Mlp::Silu(MlpSilu {
                     feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                     feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                     feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
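
The ggml path pulls each weight out of an in-memory collection by its formatted name and wraps it in QMatMul, moving each tensor out exactly once. A hedged sketch of that name-keyed lookup pattern, with a placeholder Qtensor type and an illustrative HashMap standing in for the real container:

```rust
use std::collections::HashMap;

#[derive(Debug)]
struct Qtensor; // placeholder for a quantized weight tensor

/// Remove (move out) a tensor by name, erroring if it is absent — the same
/// shape as the `ct.remove(...)?` calls above.
fn remove(map: &mut HashMap<String, Qtensor>, name: &str) -> Result<Qtensor, String> {
    map.remove(name)
        .ok_or_else(|| format!("missing tensor: {name}"))
}

fn main() -> Result<(), String> {
    let mut ct: HashMap<String, Qtensor> =
        [("layers.0.feed_forward.w1.weight".to_string(), Qtensor)].into();
    let prefix = "layers.0";
    let w1 = remove(&mut ct, &format!("{prefix}.feed_forward.w1.weight"))?;
    println!("{w1:?}");
    Ok(())
}
```
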
@@ -345,7 +361,7 @@ impl ModelWeights {
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
                 attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
-                mlp_or_moe,
+                mlp,
                 ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
                 n_head: ct.hparams.n_head as usize,
                 n_kv_head: ct.hparams.n_head as usize / gqa,
@@ -397,14 +413,14 @@ impl ModelWeights {
             let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
             let attention_wo =
                 ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
-            let mlp_or_moe = if cfg.n_expert <= 1 {
+            let mlp = if cfg.n_expert <= 1 {
                 let feed_forward_w1 =
                     ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
                 let feed_forward_w2 =
                     ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
                 let feed_forward_w3 =
                     ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
-                MlpOrMoe::Mlp(Mlp {
+                Mlp::Silu(MlpSilu {
                     feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                     feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                     feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
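
In the gguf path the same three projections are fetched under different names: w1 is stored as ffn_gate, w2 as ffn_down, and w3 as ffn_up. A tiny illustrative mapping of those suffixes (the function is hypothetical, not part of the file):

```rust
// Illustrative mapping between the ggml-era w1/w2/w3 suffixes and the gguf
// ffn_* names used above; same weights, two container formats.
fn gguf_name(ggml_suffix: &str) -> Option<&'static str> {
    match ggml_suffix {
        "feed_forward.w1" => Some("ffn_gate"), // gate projection
        "feed_forward.w2" => Some("ffn_down"), // down projection
        "feed_forward.w3" => Some("ffn_up"),   // up projection
        _ => None,
    }
}

fn main() {
    for s in ["feed_forward.w1", "feed_forward.w2", "feed_forward.w3"] {
        println!("{s} -> {:?}", gguf_name(s));
    }
}
```
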
@@ -420,13 +436,13 @@ impl ModelWeights {
                         ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
                     let feed_forward_w3 =
                         ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
-                    experts.push(Mlp {
+                    experts.push(MlpSilu {
                         feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                         feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                         feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                     })
                 }
-                MlpOrMoe::MoE {
+                Mlp::MoE {
                     n_expert_used: cfg.n_expert_used,
                     feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
                     experts,
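
Each expert keeps its own silu-gated weights under index-qualified names, which is why the experts stay a Vec<MlpSilu> after the rename. A small sketch of the names the loop above formats per expert (the prefix value is illustrative):

```rust
// Illustrative: the per-expert tensor names formatted for expert `i`
// under a given layer prefix.
fn expert_names(prefix: &str, i: usize) -> [String; 3] {
    [
        format!("{prefix}.ffn_gate.{i}.weight"),
        format!("{prefix}.ffn_down.{i}.weight"),
        format!("{prefix}.ffn_up.{i}.weight"),
    ]
}

fn main() {
    for n in expert_names("blk.0", 0) {
        println!("{n}");
    }
}
```
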
@@ -444,7 +460,7 @@ impl ModelWeights {
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
                 attention_norm: RmsNorm::from_qtensor(attention_norm, cfg.rms_norm_eps)?,
-                mlp_or_moe,
+                mlp,
                 ffn_norm: RmsNorm::from_qtensor(ffn_norm, cfg.rms_norm_eps)?,
                 n_head: cfg.head_count,
                 n_kv_head: cfg.head_count_kv,
@@ -504,7 +520,7 @@ impl ModelWeights {
             let _enter = layer.span_mlp.enter();
             let residual = &x;
             let x = layer.ffn_norm.forward(&x)?;
-            let x = layer.mlp.forward(&x)?;
+            let x = layer.mlp.forward(&x)?;
             let x = (x + residual)?;
             layer_in = x
         }
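
Only the field name changes at the call site, but the shape of the block is worth noting: it is the pre-norm residual pattern, x ← x + mlp(ffn_norm(x)). A minimal sketch with stand-in closures for RmsNorm and the Mlp (all names illustrative):

```rust
/// Pre-norm residual step: normalize, run the feed-forward, add input back.
fn layer_mlp_block(
    x: &[f32],
    norm: impl Fn(&[f32]) -> Vec<f32>,
    mlp: impl Fn(&[f32]) -> Vec<f32>,
) -> Vec<f32> {
    let residual = x;
    let h = mlp(&norm(x));
    h.iter().zip(residual).map(|(a, b)| a + b).collect()
}

fn main() {
    let identity = |v: &[f32]| v.to_vec();
    let double = |v: &[f32]| v.iter().map(|a| a * 2.0).collect::<Vec<_>>();
    // With an identity "norm" and a doubling "mlp": x + 2x = 3x per element.
    println!("{:?}", layer_mlp_block(&[1.0, 2.0], identity, double));
}
```
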