mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Rework the MLP bit.
This commit is contained in:
@ -29,13 +29,13 @@ impl QMatMul {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct Mlp {
|
struct MlpSilu {
|
||||||
feed_forward_w1: QMatMul,
|
feed_forward_w1: QMatMul,
|
||||||
feed_forward_w2: QMatMul,
|
feed_forward_w2: QMatMul,
|
||||||
feed_forward_w3: QMatMul,
|
feed_forward_w3: QMatMul,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for Mlp {
|
impl Module for MlpSilu {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
let w1 = self.feed_forward_w1.forward(xs)?;
|
let w1 = self.feed_forward_w1.forward(xs)?;
|
||||||
let w3 = self.feed_forward_w3.forward(xs)?;
|
let w3 = self.feed_forward_w3.forward(xs)?;
|
||||||
@ -45,16 +45,31 @@ impl Module for Mlp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
enum MlpOrMoe {
|
struct MlpSimple {
|
||||||
Mlp(Mlp),
|
fc1: QMatMul,
|
||||||
|
fc2: QMatMul,
|
||||||
|
act: candle_nn::Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MlpSimple {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let xs = self.fc1.forward(xs)?.apply(&self.act)?;
|
||||||
|
self.fc2.forward(&xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum Mlp {
|
||||||
|
Silu(MlpSilu),
|
||||||
|
Simple(MlpSimple),
|
||||||
MoE {
|
MoE {
|
||||||
n_expert_used: usize,
|
n_expert_used: usize,
|
||||||
feed_forward_gate_inp: QMatMul,
|
feed_forward_gate_inp: QMatMul,
|
||||||
experts: Vec<Mlp>,
|
experts: Vec<MlpSilu>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for MlpOrMoe {
|
impl Module for Mlp {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
match self {
|
match self {
|
||||||
Self::MoE {
|
Self::MoE {
|
||||||
@ -119,7 +134,8 @@ impl Module for MlpOrMoe {
|
|||||||
let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
|
let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
|
||||||
Ok(ys)
|
Ok(ys)
|
||||||
}
|
}
|
||||||
Self::Mlp(mlp) => mlp.forward(xs),
|
Self::Silu(mlp) => mlp.forward(xs),
|
||||||
|
Self::Simple(mlp) => mlp.forward(xs),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -131,7 +147,7 @@ struct LayerWeights {
|
|||||||
attention_wv: QMatMul,
|
attention_wv: QMatMul,
|
||||||
attention_wo: QMatMul,
|
attention_wo: QMatMul,
|
||||||
attention_norm: RmsNorm,
|
attention_norm: RmsNorm,
|
||||||
mlp_or_moe: MlpOrMoe,
|
mlp: Mlp,
|
||||||
ffn_norm: RmsNorm,
|
ffn_norm: RmsNorm,
|
||||||
n_head: usize,
|
n_head: usize,
|
||||||
n_kv_head: usize,
|
n_kv_head: usize,
|
||||||
@ -324,11 +340,11 @@ impl ModelWeights {
|
|||||||
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
|
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
|
||||||
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
|
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
|
||||||
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
|
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
|
||||||
let mlp_or_moe = {
|
let mlp = {
|
||||||
let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
|
let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
|
||||||
let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
|
let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
|
||||||
let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
|
let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
|
||||||
MlpOrMoe::Mlp(Mlp {
|
Mlp::Silu(MlpSilu {
|
||||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||||
@ -345,7 +361,7 @@ impl ModelWeights {
|
|||||||
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||||
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||||
attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
|
attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
|
||||||
mlp_or_moe,
|
mlp,
|
||||||
ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
|
ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
|
||||||
n_head: ct.hparams.n_head as usize,
|
n_head: ct.hparams.n_head as usize,
|
||||||
n_kv_head: ct.hparams.n_head as usize / gqa,
|
n_kv_head: ct.hparams.n_head as usize / gqa,
|
||||||
@ -397,14 +413,14 @@ impl ModelWeights {
|
|||||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
|
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
|
||||||
let attention_wo =
|
let attention_wo =
|
||||||
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
|
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
|
||||||
let mlp_or_moe = if cfg.n_expert <= 1 {
|
let mlp = if cfg.n_expert <= 1 {
|
||||||
let feed_forward_w1 =
|
let feed_forward_w1 =
|
||||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
|
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
|
||||||
let feed_forward_w2 =
|
let feed_forward_w2 =
|
||||||
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||||
let feed_forward_w3 =
|
let feed_forward_w3 =
|
||||||
ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||||
MlpOrMoe::Mlp(Mlp {
|
Mlp::Silu(MlpSilu {
|
||||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||||
@ -420,13 +436,13 @@ impl ModelWeights {
|
|||||||
ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
|
ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
|
||||||
let feed_forward_w3 =
|
let feed_forward_w3 =
|
||||||
ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
|
ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
|
||||||
experts.push(Mlp {
|
experts.push(MlpSilu {
|
||||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
MlpOrMoe::MoE {
|
Mlp::MoE {
|
||||||
n_expert_used: cfg.n_expert_used,
|
n_expert_used: cfg.n_expert_used,
|
||||||
feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
|
feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
|
||||||
experts,
|
experts,
|
||||||
@ -444,7 +460,7 @@ impl ModelWeights {
|
|||||||
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||||
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||||
attention_norm: RmsNorm::from_qtensor(attention_norm, cfg.rms_norm_eps)?,
|
attention_norm: RmsNorm::from_qtensor(attention_norm, cfg.rms_norm_eps)?,
|
||||||
mlp_or_moe,
|
mlp,
|
||||||
ffn_norm: RmsNorm::from_qtensor(ffn_norm, cfg.rms_norm_eps)?,
|
ffn_norm: RmsNorm::from_qtensor(ffn_norm, cfg.rms_norm_eps)?,
|
||||||
n_head: cfg.head_count,
|
n_head: cfg.head_count,
|
||||||
n_kv_head: cfg.head_count_kv,
|
n_kv_head: cfg.head_count_kv,
|
||||||
@ -504,7 +520,7 @@ impl ModelWeights {
|
|||||||
let _enter = layer.span_mlp.enter();
|
let _enter = layer.span_mlp.enter();
|
||||||
let residual = &x;
|
let residual = &x;
|
||||||
let x = layer.ffn_norm.forward(&x)?;
|
let x = layer.ffn_norm.forward(&x)?;
|
||||||
let x = layer.mlp_or_moe.forward(&x)?;
|
let x = layer.mlp.forward(&x)?;
|
||||||
let x = (x + residual)?;
|
let x = (x + residual)?;
|
||||||
layer_in = x
|
layer_in = x
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user