diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index ea7f70eb..070ecce1 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -67,6 +67,8 @@ enum Which {
     Mixtral,
     #[value(name = "mixtral-instruct")]
     MixtralInstruct,
+    #[value(name = "phi-2")]
+    Phi2,
 }
 
 impl Which {
@@ -82,7 +84,8 @@ impl Which {
             | Self::L13bCode
             | Self::L34bCode
             | Self::Leo7b
-            | Self::Leo13b => false,
+            | Self::Leo13b
+            | Self::Phi2 => false,
             // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
             // same way. Starling is a fine tuned version of OpenChat.
             Self::OpenChat35
@@ -116,6 +119,7 @@ impl Which {
             | Self::Mistral7bInstruct
             | Self::Mistral7bInstructV02
             | Self::OpenChat35
+            | Self::Phi2
             | Self::Starling7bAlpha => false,
             Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
         }
@@ -139,6 +143,7 @@ impl Which {
             | Self::Mistral7b
             | Self::Mistral7bInstruct
             | Self::Mistral7bInstructV02
+            | Self::Phi2
             | Self::Zephyr7bAlpha
             | Self::Zephyr7bBeta => false,
             Self::OpenChat35 | Self::Starling7bAlpha => true,
@@ -147,26 +152,27 @@ impl Which {
 
     fn tokenizer_repo(&self) -> &'static str {
         match self {
-            Which::L7b
-            | Which::L13b
-            | Which::L70b
-            | Which::L7bChat
-            | Which::L13bChat
-            | Which::L70bChat
-            | Which::L7bCode
-            | Which::L13bCode
-            | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
-            Which::Leo7b => "LeoLM/leo-hessianai-7b",
-            Which::Leo13b => "LeoLM/leo-hessianai-13b",
-            Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
-            Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
-            Which::Mistral7b
-            | Which::Mistral7bInstruct
-            | Which::Mistral7bInstructV02
-            | Which::Zephyr7bAlpha
-            | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
-            Which::OpenChat35 => "openchat/openchat_3.5",
-            Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode => "hf-internal-testing/llama-tokenizer",
+            Self::Leo7b => "LeoLM/leo-hessianai-7b",
+            Self::Leo13b => "LeoLM/leo-hessianai-13b",
+            Self::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
+            Self::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            Self::Mistral7b
+            | Self::Mistral7bInstruct
+            | Self::Mistral7bInstructV02
+            | Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+            Self::OpenChat35 => "openchat/openchat_3.5",
+            Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
+            Self::Phi2 => "microsoft/phi-2",
         }
     }
 }
@@ -322,6 +328,7 @@ impl Args {
                         "TheBloke/Starling-LM-7B-alpha-GGUF",
                         "starling-lm-7b-alpha.Q4_K_M.gguf",
                     ),
+                    Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"),
                 };
                 let api = hf_hub::api::sync::Api::new()?;
                 let api = api.model(repo.to_string());
@@ -420,7 +427,8 @@ fn main() -> anyhow::Result<()> {
                 | Which::L13bCode
                 | Which::L34bCode
                 | Which::Leo7b
-                | Which::Leo13b => 1,
+                | Which::Leo13b
+                | Which::Phi2 => 1,
                 Which::Mixtral
                 | Which::MixtralInstruct
                 | Which::Mistral7b
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 6b326fbe..5e163cb6 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -256,6 +256,58 @@ fn precomput_freqs_cis(
     Ok((cos, sin))
 }
 
+#[derive(Debug, Clone)]
+struct MetadataConfig {
+    n_expert: usize,
+    n_expert_used: usize,
+    head_count: usize,
+    head_count_kv: usize,
+    block_count: usize,
+    embedding_length: usize,
+    rope_dim: usize,
+    rms_norm_eps: f64,
+    rope_freq_base: f32,
+}
+
+impl MetadataConfig {
+    fn from_gguf(ct: &gguf_file::Content) -> Result<Self> {
+        let md_get = |s: &str| match ct.metadata.get(s) {
+            None => candle::bail!("cannot find {s} in metadata"),
+            Some(v) => Ok(v),
+        };
+
+        // Parameter extraction from metadata.
+        let n_expert = md_get("llama.expert_count")
+            .and_then(|v| v.to_u32())
+            .unwrap_or(0) as usize;
+        let n_expert_used = md_get("llama.expert_used_count")
+            .and_then(|v| v.to_u32())
+            .unwrap_or(0) as usize;
+        let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
+        let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
+        let block_count = md_get("llama.block_count")?.to_u32()? as usize;
+        let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
+        let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
+        // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
+        let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
+
+        let rope_freq_base = md_get("llama.rope.freq_base")
+            .and_then(|m| m.to_f32())
+            .unwrap_or(10000f32);
+        Ok(Self {
+            n_expert,
+            n_expert_used,
+            head_count,
+            head_count_kv,
+            block_count,
+            embedding_length,
+            rope_freq_base,
+            rope_dim,
+            rms_norm_eps,
+        })
+    }
+}
+
 impl ModelWeights {
     pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
@@ -325,48 +377,27 @@ impl ModelWeights {
         reader: &mut R,
         device: &Device,
     ) -> Result<Self> {
-        let md_get = |s: &str| match ct.metadata.get(s) {
-            None => candle::bail!("cannot find {s} in metadata"),
-            Some(v) => Ok(v),
-        };
+        let cfg = MetadataConfig::from_gguf(&ct)?;
 
-        // Parameter extraction from metadata.
-        let n_expert = md_get("llama.expert_count")
-            .and_then(|v| v.to_u32())
-            .unwrap_or(0) as usize;
-        let n_expert_used = md_get("llama.expert_used_count")
-            .and_then(|v| v.to_u32())
-            .unwrap_or(0) as usize;
-        let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
-        let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
-        let block_count = md_get("llama.block_count")?.to_u32()? as usize;
-        let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
-        let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
-        // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
-        let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
-
-        let rope_freq_base = md_get("llama.rope.freq_base")
-            .and_then(|m| m.to_f32())
-            .unwrap_or(10000f32);
-        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
+        let (cos, sin) = precomput_freqs_cis(cfg.rope_dim, cfg.rope_freq_base, device)?;
         let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
         let norm = RmsNorm::from_qtensor(
             ct.tensor(reader, "output_norm.weight", device)?,
-            rms_norm_eps,
+            cfg.rms_norm_eps,
         )?;
         let output = ct.tensor(reader, "output.weight", device)?;
-        let mut layers = Vec::with_capacity(block_count);
-        for layer_idx in 0..block_count {
+        let mut layers = Vec::with_capacity(cfg.block_count);
+        for layer_idx in 0..cfg.block_count {
             let prefix = format!("blk.{layer_idx}");
             let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
             let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
             let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
             let attention_wo =
                 ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
-            let mlp_or_moe = if n_expert <= 1 {
+            let mlp_or_moe = if cfg.n_expert <= 1 {
                 let feed_forward_w1 =
                     ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
                 let feed_forward_w2 =
@@ -381,8 +412,8 @@ impl ModelWeights {
             } else {
                 let feed_forward_gate_inp =
                     ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
-                let mut experts = Vec::with_capacity(n_expert);
-                for i in 0..n_expert {
+                let mut experts = Vec::with_capacity(cfg.n_expert);
+                for i in 0..cfg.n_expert {
                     let feed_forward_w1 =
                         ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                     let feed_forward_w2 =
@@ -396,7 +427,7 @@ impl ModelWeights {
                     })
                 }
                 MlpOrMoe::MoE {
-                    n_expert_used,
+                    n_expert_used: cfg.n_expert_used,
                     feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
                     experts,
                 }
@@ -412,12 +443,12 @@ impl ModelWeights {
                 attention_wk: QMatMul::from_qtensor(attention_wk)?,
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
-                attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
+                attention_norm: RmsNorm::from_qtensor(attention_norm, cfg.rms_norm_eps)?,
                 mlp_or_moe,
-                ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
-                n_head: head_count,
-                n_kv_head: head_count_kv,
-                head_dim: embedding_length / head_count,
+                ffn_norm: RmsNorm::from_qtensor(ffn_norm, cfg.rms_norm_eps)?,
+                n_head: cfg.head_count,
+                n_kv_head: cfg.head_count_kv,
+                head_dim: cfg.embedding_length / cfg.head_count,
                 cos: cos.clone(),
                 sin: sin.clone(),
                 neg_inf: neg_inf.clone(),
@@ -430,7 +461,7 @@ impl ModelWeights {
         let span = tracing::span!(tracing::Level::TRACE, "model");
         let span_output = tracing::span!(tracing::Level::TRACE, "output");
         Ok(Self {
-            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
+            tok_embeddings: Embedding::new(tok_embeddings, cfg.embedding_length),
             layers,
             norm,
             output: QMatMul::from_qtensor(output)?,
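
For reference, here is a minimal sketch of how the refactored loader can be driven for the new phi-2 option. It is not part of the patch: it assumes the published candle-core and candle-transformers crates plus anyhow, and a locally downloaded phi-2.Q4_K_M.gguf (the file fetched above from TheBloke/phi-2-GGUF); the path and prompt token ids are placeholders, and the real example wires this up through its CLI flags and the microsoft/phi-2 tokenizer instead.

use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_transformers::models::quantized_llama::ModelWeights;

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;

    // Read the GGUF container; head counts, block count, rope settings etc. are now
    // parsed once by MetadataConfig::from_gguf inside ModelWeights::from_gguf.
    let mut file = std::fs::File::open("phi-2.Q4_K_M.gguf")?; // placeholder path
    let content = gguf_file::Content::read(&mut file)?;
    let mut model = ModelWeights::from_gguf(content, &mut file, &device)?;

    // Placeholder prompt token ids; the quantized example produces these with the
    // tokenizer that tokenizer_repo now returns for Which::Phi2 (microsoft/phi-2).
    let prompt_tokens: Vec<u32> = vec![1, 2, 3];
    let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;

    // Single forward pass over the prompt; index_pos 0 means the tokens start at position 0.
    let logits = model.forward(&input, 0)?;
    println!("logits shape: {:?}", logits.shape());
    Ok(())
}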