Prepare for supporting phi-2 properly in the quantized model.

This commit is contained in:
laurent
2024-04-17 09:14:38 +02:00
parent 2817643db9
commit af11b2d461
2 changed files with 97 additions and 58 deletions

View File

@ -67,6 +67,8 @@ enum Which {
Mixtral,
#[value(name = "mixtral-instruct")]
MixtralInstruct,
#[value(name = "phi-2")]
Phi2,
}
impl Which {
@ -82,7 +84,8 @@ impl Which {
| Self::L13bCode
| Self::L34bCode
| Self::Leo7b
| Self::Leo13b => false,
| Self::Leo13b
| Self::Phi2 => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
@ -116,6 +119,7 @@ impl Which {
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::OpenChat35
| Self::Phi2
| Self::Starling7bAlpha => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
@ -139,6 +143,7 @@ impl Which {
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Phi2
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta => false,
Self::OpenChat35 | Self::Starling7bAlpha => true,
@ -147,26 +152,27 @@ impl Which {
fn tokenizer_repo(&self) -> &'static str {
match self {
Which::L7b
| Which::L13b
| Which::L70b
| Which::L7bChat
| Which::L13bChat
| Which::L70bChat
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode => "hf-internal-testing/llama-tokenizer",
Which::Leo7b => "LeoLM/leo-hessianai-7b",
Which::Leo13b => "LeoLM/leo-hessianai-13b",
Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Mistral7bInstructV02
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Which::OpenChat35 => "openchat/openchat_3.5",
Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode => "hf-internal-testing/llama-tokenizer",
Self::Leo7b => "LeoLM/leo-hessianai-7b",
Self::Leo13b => "LeoLM/leo-hessianai-13b",
Self::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
Self::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Self::OpenChat35 => "openchat/openchat_3.5",
Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::Phi2 => "microsoft/phi-2",
}
}
}
@ -322,6 +328,7 @@ impl Args {
"TheBloke/Starling-LM-7B-alpha-GGUF",
"starling-lm-7b-alpha.Q4_K_M.gguf",
),
Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"),
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@ -420,7 +427,8 @@ fn main() -> anyhow::Result<()> {
| Which::L13bCode
| Which::L34bCode
| Which::Leo7b
| Which::Leo13b => 1,
| Which::Leo13b
| Which::Phi2 => 1,
Which::Mixtral
| Which::MixtralInstruct
| Which::Mistral7b

View File

@ -256,6 +256,58 @@ fn precomput_freqs_cis(
Ok((cos, sin))
}
#[derive(Debug, Clone)]
struct MetadataConfig {
n_expert: usize,
n_expert_used: usize,
head_count: usize,
head_count_kv: usize,
block_count: usize,
embedding_length: usize,
rope_dim: usize,
rms_norm_eps: f64,
rope_freq_base: f32,
}
impl MetadataConfig {
fn from_gguf(ct: &gguf_file::Content) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
// Parameter extraction from metadata.
let n_expert = md_get("llama.expert_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let n_expert_used = md_get("llama.expert_used_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
Ok(Self {
n_expert,
n_expert_used,
head_count,
head_count_kv,
block_count,
embedding_length,
rope_freq_base,
rope_dim,
rms_norm_eps,
})
}
}
impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
@ -325,48 +377,27 @@ impl ModelWeights {
reader: &mut R,
device: &Device,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
let cfg = MetadataConfig::from_gguf(&ct)?;
// Parameter extraction from metadata.
let n_expert = md_get("llama.expert_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let n_expert_used = md_get("llama.expert_used_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
let (cos, sin) = precomput_freqs_cis(cfg.rope_dim, cfg.rope_freq_base, device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::from_qtensor(
ct.tensor(reader, "output_norm.weight", device)?,
rms_norm_eps,
cfg.rms_norm_eps,
)?;
let output = ct.tensor(reader, "output.weight", device)?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let mut layers = Vec::with_capacity(cfg.block_count);
for layer_idx in 0..cfg.block_count {
let prefix = format!("blk.{layer_idx}");
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
let attention_wo =
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let mlp_or_moe = if n_expert <= 1 {
let mlp_or_moe = if cfg.n_expert <= 1 {
let feed_forward_w1 =
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
let feed_forward_w2 =
@ -381,8 +412,8 @@ impl ModelWeights {
} else {
let feed_forward_gate_inp =
ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
let mut experts = Vec::with_capacity(n_expert);
for i in 0..n_expert {
let mut experts = Vec::with_capacity(cfg.n_expert);
for i in 0..cfg.n_expert {
let feed_forward_w1 =
ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
let feed_forward_w2 =
@ -396,7 +427,7 @@ impl ModelWeights {
})
}
MlpOrMoe::MoE {
n_expert_used,
n_expert_used: cfg.n_expert_used,
feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
experts,
}
@ -412,12 +443,12 @@ impl ModelWeights {
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
attention_norm: RmsNorm::from_qtensor(attention_norm, cfg.rms_norm_eps)?,
mlp_or_moe,
ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: embedding_length / head_count,
ffn_norm: RmsNorm::from_qtensor(ffn_norm, cfg.rms_norm_eps)?,
n_head: cfg.head_count,
n_kv_head: cfg.head_count_kv,
head_dim: cfg.embedding_length / cfg.head_count,
cos: cos.clone(),
sin: sin.clone(),
neg_inf: neg_inf.clone(),
@ -430,7 +461,7 @@ impl ModelWeights {
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
tok_embeddings: Embedding::new(tok_embeddings, cfg.embedding_length),
layers,
norm,
output: QMatMul::from_qtensor(output)?,