Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Quantized mixtral model (#1442)
* Add the Mixtral model.
* Add more of the mixtral layers.
* Add the final layers for mixtral.
* Sketch the expert selection.
* Add some expert routing logic.
* Hopefully finish the routing logic for mixtral.
* Add the mixtral example.
* Fix the weight filenames.
* Bugfix.
* Another fix.
* Yet another fix + remove the unused pragma.
* Shape fix.
* Support for quantized mixtral.
* Support mixtral in the quantized example.
* Mlp or moe type.
* Fix the expert field namings.
* Refactor the mlp bit.
* More MoE logic.
* Add the MoE quantized logic.
* Fix the experts length.
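As an aside, the expert-routing step this commit implements in the `MlpOrMoe::forward` added below comes down to: softmax the router logits for each token, keep the `n_expert_used` largest weights, and renormalize the kept weights so they sum to one (for Mixtral-8x7B this is a top-2 selection over 8 experts). A minimal standalone sketch of that selection on plain slices (std-only Rust, no candle types; the function names are illustrative, not part of the crate) might look like this:

// Illustrative only: per-token top-k expert selection, mirroring the
// renormalization done in the quantized MoE forward pass below.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

// Returns (expert_idx, normalized_weight) pairs for one token.
fn route_token(router_logits: &[f32], n_expert_used: usize) -> Vec<(usize, f32)> {
    let weights = softmax(router_logits);
    // Sort expert indices by routing weight, largest first.
    let mut idx: Vec<usize> = (0..weights.len()).collect();
    idx.sort_by(|&i, &j| weights[j].total_cmp(&weights[i]));
    let top = &idx[..n_expert_used.min(idx.len())];
    // Renormalize so the selected weights sum to one.
    let sum: f32 = top.iter().map(|&i| weights[i]).sum();
    top.iter().map(|&i| (i, weights[i] / sum)).collect()
}

fn main() {
    // Hypothetical router logits for one token over 8 experts.
    let logits = [0.1f32, 2.0, -1.0, 0.5, 1.7, 0.0, -0.3, 0.9];
    for (expert, w) in route_token(&logits, 2) {
        println!("expert {expert}: weight {w:.3}");
    }
}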
@@ -61,6 +61,8 @@ enum Which {
    OpenChat35,
    #[value(name = "7b-starling-a")]
    Starling7bAlpha,
    #[value(name = "mixtral")]
    Mixtral,
}

impl Which {
@@ -83,6 +85,7 @@ impl Which {
            | Self::Starling7bAlpha
            | Self::Zephyr7bAlpha
            | Self::Zephyr7bBeta
            | Self::Mixtral
            | Self::Mistral7b
            | Self::Mistral7bInstruct => true,
        }
@@ -101,6 +104,7 @@ impl Which {
            | Self::L34bCode
            | Self::Leo7b
            | Self::Leo13b
            | Self::Mixtral
            | Self::Mistral7b
            | Self::Mistral7bInstruct
            | Self::OpenChat35
@@ -122,6 +126,7 @@ impl Which {
            | Self::L34bCode
            | Self::Leo7b
            | Self::Leo13b
            | Self::Mixtral
            | Self::Mistral7b
            | Self::Mistral7bInstruct
            | Self::Zephyr7bAlpha
@@ -143,6 +148,7 @@ impl Which {
            | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
            Which::Leo7b => "LeoLM/leo-hessianai-7b",
            Which::Leo13b => "LeoLM/leo-hessianai-13b",
            Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
            Which::Mistral7b
            | Which::Mistral7bInstruct
            | Which::Zephyr7bAlpha
@@ -256,6 +262,10 @@ impl Args {
                    "TheBloke/leo-hessianai-13B-GGUF",
                    "leo-hessianai-13b.Q4_K_M.gguf",
                ),
                Which::Mixtral => (
                    "TheBloke/Mixtral-8x7B-v0.1-GGUF",
                    "mixtral-8x7b-v0.1.Q4_K_M.gguf",
                ),
                Which::Mistral7b => (
                    "TheBloke/Mistral-7B-v0.1-GGUF",
                    "mistral-7b-v0.1.Q4_K_S.gguf",
@@ -374,7 +384,8 @@ fn main() -> anyhow::Result<()> {
        | Which::L34bCode
        | Which::Leo7b
        | Which::Leo13b => 1,
        Which::Mistral7b
        Which::Mixtral
        | Which::Mistral7b
        | Which::Mistral7bInstruct
        | Which::Zephyr7bAlpha
        | Which::Zephyr7bBeta

@@ -47,6 +47,102 @@ impl QMatMul {
    }
}

#[derive(Debug, Clone)]
struct Mlp {
    feed_forward_w1: QMatMul,
    feed_forward_w2: QMatMul,
    feed_forward_w3: QMatMul,
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let w1 = self.feed_forward_w1.forward(xs)?;
        let w3 = self.feed_forward_w3.forward(xs)?;
        self.feed_forward_w2
            .forward(&(candle_nn::ops::silu(&w1)? * w3)?)
    }
}

#[derive(Debug, Clone)]
enum MlpOrMoe {
    Mlp(Mlp),
    MoE {
        n_expert_used: usize,
        feed_forward_gate_inp: QMatMul,
        experts: Vec<Mlp>,
    },
}

impl Module for MlpOrMoe {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::MoE {
                feed_forward_gate_inp,
                experts,
                n_expert_used,
            } => {
                let (b_size, seq_len, hidden_dim) = xs.dims3()?;
                let xs = xs.reshape(((), hidden_dim))?;
                let router_logits = feed_forward_gate_inp.forward(&xs)?;
                let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

                // In order to extract topk, we extract the data from the tensor and manipulate it
                // directly. Maybe we will want to use some custom ops instead at some point.
                let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

                // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
                // top_x contains the row indexes to evaluate for each expert.
                let mut top_x = vec![vec![]; experts.len()];
                let mut selected_rws = vec![vec![]; experts.len()];
                for (row_idx, rw) in routing_weights.iter().enumerate() {
                    let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
                    dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
                    let mut sum_routing_weights = 0f32;
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        sum_routing_weights += routing_weight;
                        top_x[expert_idx].push(row_idx as u32);
                    }
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
                    }
                }

                // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
                // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

                let mut ys = xs.zeros_like()?;
                for (expert_idx, expert_layer) in experts.iter().enumerate() {
                    let top_x = &top_x[expert_idx];
                    if top_x.is_empty() {
                        continue;
                    }
                    let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
                    let selected_rws =
                        Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
                            .reshape(((), 1))?;
                    // Index the correct hidden states and compute the expert hidden state for
                    // the current expert. We need to make sure to multiply the output hidden
                    // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
                    let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                    // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
                    let current_hidden_states = expert_layer.forward(&current_state)?;
                    let current_hidden_states =
                        current_hidden_states.broadcast_mul(&selected_rws)?;
                    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
                }

                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                Ok(ys)
            }
            Self::Mlp(mlp) => mlp.forward(xs),
        }
    }
}

#[derive(Debug, Clone)]
struct LayerWeights {
    attention_wq: QMatMul,
@@ -54,9 +150,7 @@ struct LayerWeights {
    attention_wv: QMatMul,
    attention_wo: QMatMul,
    attention_norm: RmsNorm,
    feed_forward_w1: QMatMul,
    feed_forward_w2: QMatMul,
    feed_forward_w3: QMatMul,
    mlp_or_moe: MlpOrMoe,
    ffn_norm: RmsNorm,
    n_head: usize,
    n_kv_head: usize,
@@ -212,9 +306,16 @@ impl ModelWeights {
            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
            let mlp_or_moe = {
                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                })
            };
            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
@@ -226,9 +327,7 @@ impl ModelWeights {
                attention_wv: QMatMul::from_qtensor(attention_wv)?,
                attention_wo: QMatMul::from_qtensor(attention_wo)?,
                attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
                feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                mlp_or_moe,
                ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
                n_head: ct.hparams.n_head as usize,
                n_kv_head: ct.hparams.n_head as usize / gqa,
@@ -265,6 +364,12 @@ impl ModelWeights {
        };

        // Parameter extraction from metadata.
        let n_expert = md_get("llama.expert_count")
            .and_then(|v| v.to_u32())
            .unwrap_or(0) as usize;
        let n_expert_used = md_get("llama.expert_used_count")
            .and_then(|v| v.to_u32())
            .unwrap_or(0) as usize;
        let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
        let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
        let block_count = md_get("llama.block_count")?.to_u32()? as usize;
@@ -289,9 +394,38 @@ impl ModelWeights {
            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
            let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
            let mlp_or_moe = if n_expert <= 1 {
                let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
                let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
                let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                })
            } else {
                let feed_forward_gate_inp =
                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
                let mut experts = Vec::with_capacity(n_expert);
                for i in 0..n_expert {
                    let feed_forward_w1 =
                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
                    let feed_forward_w2 =
                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
                    let feed_forward_w3 =
                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
                    experts.push(Mlp {
                        feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                        feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                        feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                    })
                }
                MlpOrMoe::MoE {
                    n_expert_used,
                    feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
                    experts,
                }
            };
            let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
@@ -303,9 +437,7 @@ impl ModelWeights {
                attention_wv: QMatMul::from_qtensor(attention_wv)?,
                attention_wo: QMatMul::from_qtensor(attention_wo)?,
                attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
                feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                mlp_or_moe,
                ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
@@ -360,12 +492,9 @@ impl ModelWeights {
            let _enter = layer.span_mlp.enter();
            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let w1 = layer.feed_forward_w1.forward(&x)?;
            let w3 = layer.feed_forward_w3.forward(&x)?;
            let mlp = layer
                .feed_forward_w2
                .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?;
            layer_in = (mlp + residual)?;
            let x = layer.mlp_or_moe.forward(&x)?;
            let x = (x + residual)?;
            layer_in = x
        }
        let x = self.norm.forward(&layer_in)?;
        let x = x.i((.., seq_len - 1, ..))?;
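
To make the dispatch step in `MlpOrMoe::forward` more concrete: for each expert, the forward pass gathers the rows of the tokens routed to it (`index_select`), runs the expert MLP on that sub-batch, scales the result by the renormalized routing weight, and scatter-adds it back into the output (`index_add`). A rough std-only sketch of the same pattern on plain vectors (illustrative names, dummy experts standing in for the quantized MLPs, not candle's API; the sketch applies the expert row by row, whereas the candle code batches all of an expert's rows into one matmul):

// Illustrative only: the gather / weighted scatter-add pattern used above,
// written out on plain vectors instead of candle tensors.
type Expert = fn(&[f32]) -> Vec<f32>;

fn moe_dispatch(
    xs: &[Vec<f32>],           // (n_tokens, hidden_dim) input rows
    top_x: &[Vec<usize>],      // per expert: indices of the tokens routed to it
    selected_rws: &[Vec<f32>], // per expert: matching renormalized routing weights
    experts: &[Expert],
) -> Vec<Vec<f32>> {
    let hidden_dim = xs.first().map_or(0, |row| row.len());
    let mut ys = vec![vec![0f32; hidden_dim]; xs.len()];
    for (expert_idx, expert) in experts.iter().enumerate() {
        // "index_select": visit only the rows assigned to this expert.
        for (&row, &w) in top_x[expert_idx].iter().zip(&selected_rws[expert_idx]) {
            let out = expert(&xs[row]);
            // "index_add": accumulate the weighted expert output into the output row.
            for (y, o) in ys[row].iter_mut().zip(&out) {
                *y += w * o;
            }
        }
    }
    ys
}

fn expert_identity(x: &[f32]) -> Vec<f32> {
    x.to_vec()
}

fn expert_double(x: &[f32]) -> Vec<f32> {
    x.iter().map(|v| v * 2.0).collect()
}

fn main() {
    // Two tokens with hidden_dim = 2 and two dummy experts.
    let xs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let experts: Vec<Expert> = vec![expert_identity, expert_double];
    // Token 0 goes to experts 0 and 1 (weights 0.6 / 0.4), token 1 only to expert 1.
    let top_x = vec![vec![0], vec![0, 1]];
    let selected_rws = vec![vec![0.6], vec![0.4, 1.0]];
    println!("{:?}", moe_dispatch(&xs, &top_x, &selected_rws, &experts));
}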