Quantized Mixtral model (#1442)

* Add the Mixtral model.

* Add more of the Mixtral layers.

* Add the final layers for Mixtral.

* Sketch the expert selection.

* Add some expert routing logic.

* Hopefully finish the routing logic for Mixtral.

* Add the Mixtral example.

* Fix the weight filenames.

* Bugfix.

* Another fix.

* Yet another fix + remove the unused pragma.

* Shape fix.

* Support for quantized Mixtral.

* Support Mixtral in the quantized example.

* Mlp or MoE type.

* Fix the expert field naming.

* Refactor the MLP bit.

* More MoE logic.

* Add the MoE quantized logic.

* Fix the experts length.
Author: Laurent Mazare
Date: 2023-12-15 19:16:06 -06:00 (committed by GitHub)
Parent: 614842b311
Commit: 30a958e5dd
2 changed files with 162 additions and 22 deletions
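The first set of hunks below adds Mixtral to the example's model selection and to its GGUF repo/filename table; the second file contains the MoE layer itself. Before the diff, here is a small standalone sketch of the expert-selection step the commits above describe: given one token's routing weights (the softmax of the router logits), keep the n_expert_used largest entries and renormalize them so they sum to one. This is plain Rust with made-up numbers, not code from the PR.

// Hypothetical helper, for illustration only.
fn top_k_routing(routing_weights: &[f32], n_expert_used: usize) -> Vec<(usize, f32)> {
    let mut idx: Vec<usize> = (0..routing_weights.len()).collect();
    // Sort expert indices by decreasing routing weight.
    idx.sort_by(|&i, &j| routing_weights[j].total_cmp(&routing_weights[i]));
    let top = &idx[..n_expert_used.min(idx.len())];
    // Renormalize the kept weights so they sum to one.
    let sum: f32 = top.iter().map(|&i| routing_weights[i]).sum();
    top.iter().map(|&i| (i, routing_weights[i] / sum)).collect()
}

fn main() {
    // E.g. a softmax output over 8 experts for one token.
    let rw: [f32; 8] = [0.05, 0.30, 0.02, 0.25, 0.10, 0.08, 0.12, 0.08];
    // With n_expert_used = 2 this keeps experts 1 and 3, with weights
    // 0.30 / 0.55 ≈ 0.55 and 0.25 / 0.55 ≈ 0.45.
    println!("{:?}", top_k_routing(&rw, 2));
}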

@@ -61,6 +61,8 @@ enum Which {
    OpenChat35,
    #[value(name = "7b-starling-a")]
    Starling7bAlpha,
    #[value(name = "mixtral")]
    Mixtral,
}
impl Which {
@@ -83,6 +85,7 @@ impl Which {
            | Self::Starling7bAlpha
            | Self::Zephyr7bAlpha
            | Self::Zephyr7bBeta
            | Self::Mixtral
            | Self::Mistral7b
            | Self::Mistral7bInstruct => true,
        }
@@ -101,6 +104,7 @@ impl Which {
            | Self::L34bCode
            | Self::Leo7b
            | Self::Leo13b
            | Self::Mixtral
            | Self::Mistral7b
            | Self::Mistral7bInstruct
            | Self::OpenChat35
@@ -122,6 +126,7 @@ impl Which {
            | Self::L34bCode
            | Self::Leo7b
            | Self::Leo13b
            | Self::Mixtral
            | Self::Mistral7b
            | Self::Mistral7bInstruct
            | Self::Zephyr7bAlpha
@@ -143,6 +148,7 @@ impl Which {
            | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
            Which::Leo7b => "LeoLM/leo-hessianai-7b",
            Which::Leo13b => "LeoLM/leo-hessianai-13b",
            Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
            Which::Mistral7b
            | Which::Mistral7bInstruct
            | Which::Zephyr7bAlpha
@@ -256,6 +262,10 @@ impl Args {
                    "TheBloke/leo-hessianai-13B-GGUF",
                    "leo-hessianai-13b.Q4_K_M.gguf",
                ),
                Which::Mixtral => (
                    "TheBloke/Mixtral-8x7B-v0.1-GGUF",
                    "mixtral-8x7b-v0.1.Q4_K_M.gguf",
                ),
                Which::Mistral7b => (
                    "TheBloke/Mistral-7B-v0.1-GGUF",
                    "mistral-7b-v0.1.Q4_K_S.gguf",
@@ -374,7 +384,8 @@ fn main() -> anyhow::Result<()> {
        | Which::L34bCode
        | Which::Leo7b
        | Which::Leo13b => 1,
        Which::Mistral7b
        Which::Mixtral
        | Which::Mistral7b
        | Which::Mistral7bInstruct
        | Which::Zephyr7bAlpha
        | Which::Zephyr7bBeta
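That is the extent of the example-side changes; the hunks below are in the model implementation itself. The (repo, filename) pair added above is just data. As a rough illustration only, assuming the example resolves these pairs through the hf-hub crate (the download path is not shown in this diff), fetching the Mixtral GGUF could look like:

use hf_hub::api::sync::Api;

fn main() -> anyhow::Result<()> {
    // hf-hub's blocking ("sync") API; caches under the usual HF cache directory.
    let api = Api::new()?;
    let repo = api.model("TheBloke/Mixtral-8x7B-v0.1-GGUF".to_string());
    let model_path = repo.get("mixtral-8x7b-v0.1.Q4_K_M.gguf")?;
    println!("downloaded to {model_path:?}");
    Ok(())
}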

@@ -47,6 +47,102 @@ impl QMatMul {
    }
}

#[derive(Debug, Clone)]
struct Mlp {
    feed_forward_w1: QMatMul,
    feed_forward_w2: QMatMul,
    feed_forward_w3: QMatMul,
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let w1 = self.feed_forward_w1.forward(xs)?;
        let w3 = self.feed_forward_w3.forward(xs)?;
        self.feed_forward_w2
            .forward(&(candle_nn::ops::silu(&w1)? * w3)?)
    }
}

#[derive(Debug, Clone)]
enum MlpOrMoe {
    Mlp(Mlp),
    MoE {
        n_expert_used: usize,
        feed_forward_gate_inp: QMatMul,
        experts: Vec<Mlp>,
    },
}

impl Module for MlpOrMoe {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::MoE {
                feed_forward_gate_inp,
                experts,
                n_expert_used,
            } => {
                let (b_size, seq_len, hidden_dim) = xs.dims3()?;
                let xs = xs.reshape(((), hidden_dim))?;
                let router_logits = feed_forward_gate_inp.forward(&xs)?;
                let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

                // In order to extract topk, we extract the data from the tensor and manipulate it
                // directly. Maybe we will want to use some custom ops instead at some point.
                let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

                // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
                // top_x contains the row indexes to evaluate for each expert.
                let mut top_x = vec![vec![]; experts.len()];
                let mut selected_rws = vec![vec![]; experts.len()];
                for (row_idx, rw) in routing_weights.iter().enumerate() {
                    let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
                    dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
                    let mut sum_routing_weights = 0f32;
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        sum_routing_weights += routing_weight;
                        top_x[expert_idx].push(row_idx as u32);
                    }
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
                    }
                }

                // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
                // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
                let mut ys = xs.zeros_like()?;
                for (expert_idx, expert_layer) in experts.iter().enumerate() {
                    let top_x = &top_x[expert_idx];
                    if top_x.is_empty() {
                        continue;
                    }
                    let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
                    let selected_rws =
                        Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
                            .reshape(((), 1))?;
                    // Index the correct hidden states and compute the expert hidden state for
                    // the current expert. We need to make sure to multiply the output hidden
                    // states by `routing_weights` on the corresponding tokens (top-1 and top-2).
                    let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                    // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
                    let current_hidden_states = expert_layer.forward(&current_state)?;
                    let current_hidden_states =
                        current_hidden_states.broadcast_mul(&selected_rws)?;
                    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
                }
                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                Ok(ys)
            }
            Self::Mlp(mlp) => mlp.forward(xs),
        }
    }
}

#[derive(Debug, Clone)]
struct LayerWeights {
    attention_wq: QMatMul,
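A side note on the routing loop above: tokens are gathered per expert with index_select, run through that expert's MLP, scaled by their renormalized routing weight, and scattered back with index_add. Below is a minimal sketch of just that gather/scale/scatter step, using the same candle tensor calls on toy data; the expert itself is replaced by the identity, so this is an illustration rather than code from the PR.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Four token rows with hidden size 3.
    let xs = Tensor::new(
        &[[1f32, 1., 1.], [2., 2., 2.], [3., 3., 3.], [4., 4., 4.]],
        &dev,
    )?;
    let mut ys = xs.zeros_like()?;
    // Pretend the router sent rows 1 and 3 to this expert, with renormalized
    // weights 0.25 and 0.75 respectively.
    let top_x = Tensor::new(&[1u32, 3u32], &dev)?;
    let selected_rws = Tensor::new(&[0.25f32, 0.75], &dev)?.reshape(((), 1))?;
    // Gather the selected rows, "apply the expert" (identity here), scale them by
    // their routing weights, and scatter the results back into the output buffer.
    let current_state = xs.index_select(&top_x, 0)?;
    let current_hidden_states = current_state.broadcast_mul(&selected_rws)?;
    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
    println!("{ys}");
    Ok(())
}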
@@ -54,9 +150,7 @@ struct LayerWeights {
    attention_wv: QMatMul,
    attention_wo: QMatMul,
    attention_norm: RmsNorm,
    feed_forward_w1: QMatMul,
    feed_forward_w2: QMatMul,
    feed_forward_w3: QMatMul,
    mlp_or_moe: MlpOrMoe,
    ffn_norm: RmsNorm,
    n_head: usize,
    n_kv_head: usize,
@@ -212,9 +306,16 @@ impl ModelWeights {
            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
            // This loading path always builds the plain feed-forward variant (no experts).
            let mlp_or_moe = {
                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                })
            };
            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
@@ -226,9 +327,7 @@ impl ModelWeights {
                attention_wv: QMatMul::from_qtensor(attention_wv)?,
                attention_wo: QMatMul::from_qtensor(attention_wo)?,
                attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
                feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                mlp_or_moe,
                ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
                n_head: ct.hparams.n_head as usize,
                n_kv_head: ct.hparams.n_head as usize / gqa,
@@ -265,6 +364,12 @@ impl ModelWeights {
        };

        // Parameter extraction from metadata.
        // The expert counts are optional: non-MoE models do not set them, hence the 0 default.
        let n_expert = md_get("llama.expert_count")
            .and_then(|v| v.to_u32())
            .unwrap_or(0) as usize;
        let n_expert_used = md_get("llama.expert_used_count")
            .and_then(|v| v.to_u32())
            .unwrap_or(0) as usize;
        let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
        let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
        let block_count = md_get("llama.block_count")?.to_u32()? as usize;
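The two llama.expert_* keys above come straight from the GGUF metadata. As a standalone sketch of reading them (reusing candle's gguf_file reader that this module already relies on; the helper name expert_counts and the anyhow error type are illustrative, not part of the PR):

use candle_core::quantized::gguf_file;

fn expert_counts(path: &str) -> anyhow::Result<(usize, usize)> {
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    // Non-MoE models simply omit these keys, hence the 0 default.
    let get = |key: &str| {
        content
            .metadata
            .get(key)
            .and_then(|v| v.to_u32().ok())
            .unwrap_or(0) as usize
    };
    Ok((get("llama.expert_count"), get("llama.expert_used_count")))
}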
@@ -289,9 +394,38 @@ impl ModelWeights {
            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
            let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
            let mlp_or_moe = if n_expert <= 1 {
                // Dense model: a single SwiGLU feed-forward block per layer.
                let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
                let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
                let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                })
            } else {
                // MoE model: one gating matmul plus an indexed set of expert MLPs.
                let feed_forward_gate_inp =
                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
                let mut experts = Vec::with_capacity(n_expert);
                for i in 0..n_expert {
                    let feed_forward_w1 =
                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
                    let feed_forward_w2 =
                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
                    let feed_forward_w3 =
                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
                    experts.push(Mlp {
                        feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                        feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                        feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                    })
                }
                MlpOrMoe::MoE {
                    n_expert_used,
                    feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
                    experts,
                }
            };
            let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
@@ -303,9 +437,7 @@ impl ModelWeights {
                attention_wv: QMatMul::from_qtensor(attention_wv)?,
                attention_wo: QMatMul::from_qtensor(attention_wo)?,
                attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
                feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                mlp_or_moe,
                ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
@@ -360,12 +492,9 @@ impl ModelWeights {
            let _enter = layer.span_mlp.enter();
            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let w1 = layer.feed_forward_w1.forward(&x)?;
            let w3 = layer.feed_forward_w3.forward(&x)?;
            let mlp = layer
                .feed_forward_w2
                .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?;
            layer_in = (mlp + residual)?;
            let x = layer.mlp_or_moe.forward(&x)?;
            let x = (x + residual)?;
            layer_in = x
        }
        let x = self.norm.forward(&layer_in)?;
        let x = x.i((.., seq_len - 1, ..))?;
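A last note on the refactor in this final hunk: the per-layer feed-forward math is unchanged, only moved behind MlpOrMoe::forward. Both the dense path and each expert compute w2(silu(w1 x) * w3 x), i.e. a SwiGLU block. A tiny sketch of the silu activation on its own, using the same candle_nn::ops::silu helper called above, with toy input values:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // silu(x) = x * sigmoid(x); candle_nn exposes it directly.
    let x = Tensor::new(&[-2f32, -1., 0., 1., 2.], &Device::Cpu)?;
    let y = candle_nn::ops::silu(&x)?;
    println!("{y}");
    Ok(())
}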