Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Rework the MLP bit.
@@ -29,13 +29,13 @@ impl QMatMul {
 }
 
 #[derive(Debug, Clone)]
-struct Mlp {
+struct MlpSilu {
     feed_forward_w1: QMatMul,
     feed_forward_w2: QMatMul,
     feed_forward_w3: QMatMul,
 }
 
-impl Module for Mlp {
+impl Module for MlpSilu {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let w1 = self.feed_forward_w1.forward(xs)?;
         let w3 = self.feed_forward_w3.forward(xs)?;
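The hunk ends inside the forward body; from the field names above, this forward presumably completes as a SwiGLU-style gate, i.e. silu(w1(x)) * w3(x) projected through w2. A minimal sketch written against the struct as shown (not copied from the commit, assuming the module's existing candle imports):

impl MlpSilu {
    // Sketch of the full forward implied by the hunk above:
    // down-project the silu-gated product of the two up-projections.
    fn forward_sketch(&self, xs: &Tensor) -> Result<Tensor> {
        let w1 = self.feed_forward_w1.forward(xs)?;
        let w3 = self.feed_forward_w3.forward(xs)?;
        let gated = (candle_nn::ops::silu(&w1)? * w3)?;
        self.feed_forward_w2.forward(&gated)
    }
}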
@@ -45,16 +45,31 @@ impl Module for Mlp {
 }
 
 #[derive(Debug, Clone)]
-enum MlpOrMoe {
-    Mlp(Mlp),
+struct MlpSimple {
+    fc1: QMatMul,
+    fc2: QMatMul,
+    act: candle_nn::Activation,
+}
+
+impl Module for MlpSimple {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.fc1.forward(xs)?.apply(&self.act)?;
+        self.fc2.forward(&xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+enum Mlp {
+    Silu(MlpSilu),
+    Simple(MlpSimple),
     MoE {
         n_expert_used: usize,
         feed_forward_gate_inp: QMatMul,
-        experts: Vec<Mlp>,
+        experts: Vec<MlpSilu>,
     },
 }
 
-impl Module for MlpOrMoe {
+impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::MoE {
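The new Simple variant is declared here but none of the hunks below construct one. As an illustration only, building and applying it could look like the following; the up/down tensor roles and the Gelu choice are assumptions for the example, not taken from this commit:

use candle::quantized::QTensor;

// Hypothetical helper: wrap two already-loaded quantized tensors in the
// non-gated MLP variant and run it through the shared Module dispatch.
fn run_simple_mlp(up: QTensor, down: QTensor, xs: &Tensor) -> Result<Tensor> {
    let mlp = Mlp::Simple(MlpSimple {
        fc1: QMatMul::from_qtensor(up)?,
        fc2: QMatMul::from_qtensor(down)?,
        act: candle_nn::Activation::Gelu, // activation choice is illustrative
    });
    mlp.forward(xs)
}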
@@ -119,7 +134,8 @@ impl Module for MlpOrMoe {
                 let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                 Ok(ys)
             }
-            Self::Mlp(mlp) => mlp.forward(xs),
+            Self::Silu(mlp) => mlp.forward(xs),
+            Self::Simple(mlp) => mlp.forward(xs),
         }
     }
 }
@@ -131,7 +147,7 @@ struct LayerWeights {
     attention_wv: QMatMul,
     attention_wo: QMatMul,
     attention_norm: RmsNorm,
-    mlp_or_moe: MlpOrMoe,
+    mlp: Mlp,
     ffn_norm: RmsNorm,
     n_head: usize,
     n_kv_head: usize,
@@ -324,11 +340,11 @@ impl ModelWeights {
             let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
             let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
             let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
-            let mlp_or_moe = {
+            let mlp = {
                 let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                 let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                 let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
-                MlpOrMoe::Mlp(Mlp {
+                Mlp::Silu(MlpSilu {
                     feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                     feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                     feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
@@ -345,7 +361,7 @@ impl ModelWeights {
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
                 attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
-                mlp_or_moe,
+                mlp,
                 ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
                 n_head: ct.hparams.n_head as usize,
                 n_kv_head: ct.hparams.n_head as usize / gqa,
@@ -397,14 +413,14 @@ impl ModelWeights {
             let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
             let attention_wo =
                 ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
-            let mlp_or_moe = if cfg.n_expert <= 1 {
+            let mlp = if cfg.n_expert <= 1 {
                 let feed_forward_w1 =
                     ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
                 let feed_forward_w2 =
                     ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
                 let feed_forward_w3 =
                     ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
-                MlpOrMoe::Mlp(Mlp {
+                Mlp::Silu(MlpSilu {
                     feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                     feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                     feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
@@ -420,13 +436,13 @@ impl ModelWeights {
                         ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
                     let feed_forward_w3 =
                         ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
-                    experts.push(Mlp {
+                    experts.push(MlpSilu {
                         feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                         feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                         feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                     })
                 }
-                MlpOrMoe::MoE {
+                Mlp::MoE {
                     n_expert_used: cfg.n_expert_used,
                     feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
                     experts,
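In the gguf path the choice above is still only between Silu and MoE, driven by cfg.n_expert. Purely as a sketch of where the new Simple variant could slot in, and not something this commit does, a loader might fall back to it when a checkpoint ships no ffn_gate tensor. The tensor names and calls mirror the surrounding loader, the fallback logic and activation pick are assumptions, and the module's existing imports (gguf_file, Device, etc.) are assumed:

// Hypothetical helper, not introduced by this commit: pick MlpSimple when the
// checkpoint has no ffn_gate tensor, otherwise build the SwiGLU variant the
// same way the loader above does.
fn mlp_from_gguf<R: std::io::Read + std::io::Seek>(
    ct: &gguf_file::Content,
    reader: &mut R,
    prefix: &str,
    device: &Device,
) -> Result<Mlp> {
    match ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device) {
        Ok(feed_forward_w1) => {
            let feed_forward_w2 =
                ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
            let feed_forward_w3 =
                ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
            Ok(Mlp::Silu(MlpSilu {
                feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
            }))
        }
        Err(_) => {
            let fc1 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
            let fc2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
            Ok(Mlp::Simple(MlpSimple {
                fc1: QMatMul::from_qtensor(fc1)?,
                fc2: QMatMul::from_qtensor(fc2)?,
                act: candle_nn::Activation::Silu, // placeholder activation
            }))
        }
    }
}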
@@ -444,7 +460,7 @@ impl ModelWeights {
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
                 attention_norm: RmsNorm::from_qtensor(attention_norm, cfg.rms_norm_eps)?,
-                mlp_or_moe,
+                mlp,
                 ffn_norm: RmsNorm::from_qtensor(ffn_norm, cfg.rms_norm_eps)?,
                 n_head: cfg.head_count,
                 n_kv_head: cfg.head_count_kv,
@@ -504,7 +520,7 @@ impl ModelWeights {
             let _enter = layer.span_mlp.enter();
             let residual = &x;
             let x = layer.ffn_norm.forward(&x)?;
-            let x = layer.mlp_or_moe.forward(&x)?;
+            let x = layer.mlp.forward(&x)?;
             let x = (x + residual)?;
             layer_in = x
         }