Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Compare commits
3 Commits: 3754b834f4, d79041d94d, af11b2d461
@@ -67,6 +67,8 @@ enum Which {
     Mixtral,
     #[value(name = "mixtral-instruct")]
     MixtralInstruct,
+    #[value(name = "phi-2")]
+    Phi2,
 }
 
 impl Which {
@@ -82,7 +84,8 @@ impl Which {
            | Self::L13bCode
            | Self::L34bCode
            | Self::Leo7b
-            | Self::Leo13b => false,
+            | Self::Leo13b
+            | Self::Phi2 => false,
            // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
            // same way. Starling is a fine tuned version of OpenChat.
            Self::OpenChat35
@@ -116,6 +119,7 @@ impl Which {
            | Self::Mistral7bInstruct
            | Self::Mistral7bInstructV02
            | Self::OpenChat35
+            | Self::Phi2
            | Self::Starling7bAlpha => false,
            Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
        }
@@ -139,6 +143,7 @@ impl Which {
            | Self::Mistral7b
            | Self::Mistral7bInstruct
            | Self::Mistral7bInstructV02
+            | Self::Phi2
            | Self::Zephyr7bAlpha
            | Self::Zephyr7bBeta => false,
            Self::OpenChat35 | Self::Starling7bAlpha => true,
@@ -147,26 +152,27 @@ impl Which {
 
     fn tokenizer_repo(&self) -> &'static str {
         match self {
-            Which::L7b
-            | Which::L13b
-            | Which::L70b
-            | Which::L7bChat
-            | Which::L13bChat
-            | Which::L70bChat
-            | Which::L7bCode
-            | Which::L13bCode
-            | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
-            Which::Leo7b => "LeoLM/leo-hessianai-7b",
-            Which::Leo13b => "LeoLM/leo-hessianai-13b",
-            Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
-            Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
-            Which::Mistral7b
-            | Which::Mistral7bInstruct
-            | Which::Mistral7bInstructV02
-            | Which::Zephyr7bAlpha
-            | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
-            Which::OpenChat35 => "openchat/openchat_3.5",
-            Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode => "hf-internal-testing/llama-tokenizer",
+            Self::Leo7b => "LeoLM/leo-hessianai-7b",
+            Self::Leo13b => "LeoLM/leo-hessianai-13b",
+            Self::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
+            Self::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            Self::Mistral7b
+            | Self::Mistral7bInstruct
+            | Self::Mistral7bInstructV02
+            | Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+            Self::OpenChat35 => "openchat/openchat_3.5",
+            Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
+            Self::Phi2 => "microsoft/phi-2",
         }
     }
 }
@@ -322,6 +328,7 @@ impl Args {
                    "TheBloke/Starling-LM-7B-alpha-GGUF",
                    "starling-lm-7b-alpha.Q4_K_M.gguf",
                ),
+                Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"),
            };
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(repo.to_string());
@@ -420,7 +427,8 @@ fn main() -> anyhow::Result<()> {
                Which::L13bCode
                | Which::L34bCode
                | Which::Leo7b
-                | Which::Leo13b => 1,
+                | Which::Leo13b
+                | Which::Phi2 => 1,
                Which::Mixtral
                | Which::MixtralInstruct
                | Which::Mistral7b
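The hunks above add a `phi-2` variant to the example's `Which` enum: it is excluded from the Mistral/Zephyr/OpenChat handling, its tokenizer is fetched from `microsoft/phi-2`, and its default weights from `TheBloke/phi-2-GGUF` (`phi-2.Q4_K_M.gguf`). A minimal standalone sketch of the clap wiring this relies on, where the `#[value(name = "phi-2")]` attribute maps a command-line value onto the new variant; the flag name and default below are illustrative, not necessarily the example's exact CLI:

```rust
use clap::{Parser, ValueEnum};

/// Model variants, mirroring the shape of the example's `Which` enum.
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    #[value(name = "mixtral")]
    Mixtral,
    #[value(name = "phi-2")]
    Phi2,
}

#[derive(Parser, Debug)]
struct Args {
    /// Which model to run; `--which phi-2` parses into `Which::Phi2`.
    #[arg(long, value_enum, default_value = "mixtral")]
    which: Which,
}

fn main() {
    let args = Args::parse();
    println!("selected: {:?}", args.which);
}
```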
@@ -1,6 +1,5 @@
 use std::collections::HashMap;
 
-use crate::quantized_nn::RmsNorm;
 use candle::quantized::QTensor;
 use candle::quantized::{ggml_file, gguf_file};
 use candle::{DType, Device, IndexOp, Result, Tensor};
@@ -29,13 +28,13 @@ impl QMatMul {
 }
 
 #[derive(Debug, Clone)]
-struct Mlp {
+struct MlpSilu {
     feed_forward_w1: QMatMul,
     feed_forward_w2: QMatMul,
     feed_forward_w3: QMatMul,
 }
 
-impl Module for Mlp {
+impl Module for MlpSilu {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let w1 = self.feed_forward_w1.forward(xs)?;
         let w3 = self.feed_forward_w3.forward(xs)?;
@@ -45,16 +44,31 @@ impl Module for Mlp {
 }
 
 #[derive(Debug, Clone)]
-enum MlpOrMoe {
-    Mlp(Mlp),
+struct MlpSimple {
+    fc1: QMatMul,
+    fc2: QMatMul,
+    act: candle_nn::Activation,
+}
+
+impl Module for MlpSimple {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.fc1.forward(xs)?.apply(&self.act)?;
+        self.fc2.forward(&xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+enum Mlp {
+    Silu(MlpSilu),
+    Simple(MlpSimple),
     MoE {
         n_expert_used: usize,
         feed_forward_gate_inp: QMatMul,
-        experts: Vec<Mlp>,
+        experts: Vec<MlpSilu>,
     },
 }
 
-impl Module for MlpOrMoe {
+impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::MoE {
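These hunks split the former `Mlp` struct and `MlpOrMoe` wrapper into a single `Mlp` enum with three cases: the renamed gated `MlpSilu`, a new two-projection `MlpSimple` used for phi-2 (fc1, activation, fc2), and the existing MoE case. A rough reference of the two dense feed-forward shapes, written against plain tensors instead of `QMatMul`; the SiLU gating of `MlpSilu` is only partially visible in this excerpt, so the combination below is the usual llama formulation rather than a quote of the crate's code:

```rust
use candle::{Device, Result, Tensor};

// Llama-style gated feed-forward: w2(silu(w1 x) * (w3 x)).
fn silu_mlp(x: &Tensor, w1: &Tensor, w2: &Tensor, w3: &Tensor) -> Result<Tensor> {
    let gate = candle_nn::ops::silu(&x.matmul(w1)?)?;
    let up = x.matmul(w3)?;
    (gate * up)?.matmul(w2)
}

// Phi-2-style feed-forward: fc2(new_gelu(fc1 x)).
fn simple_mlp(x: &Tensor, fc1: &Tensor, fc2: &Tensor) -> Result<Tensor> {
    x.matmul(fc1)?.gelu()?.matmul(fc2)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::randn(0f32, 1., (2, 8), &dev)?;
    let w1 = Tensor::randn(0f32, 1., (8, 32), &dev)?;
    let w3 = Tensor::randn(0f32, 1., (8, 32), &dev)?;
    let w2 = Tensor::randn(0f32, 1., (32, 8), &dev)?;
    println!("gated mlp output dims: {:?}", silu_mlp(&x, &w1, &w2, &w3)?.dims());
    println!("simple mlp output dims: {:?}", simple_mlp(&x, &w1, &w2)?.dims());
    Ok(())
}
```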
@@ -119,20 +133,48 @@ impl Module for MlpOrMoe {
                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                Ok(ys)
            }
-            Self::Mlp(mlp) => mlp.forward(xs),
+            Self::Silu(mlp) => mlp.forward(xs),
+            Self::Simple(mlp) => mlp.forward(xs),
        }
    }
 }
 
+#[derive(Debug, Clone)]
+enum Norm {
+    Rms(crate::quantized_nn::RmsNorm),
+    Layer(candle_nn::LayerNorm),
+}
+
+impl Module for Norm {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::Rms(m) => m.forward(xs),
+            Self::Layer(m) => m.forward(xs),
+        }
+    }
+}
+
+fn rms_norm(q: QTensor, eps: f64) -> Result<Norm> {
+    let rms = crate::quantized_nn::RmsNorm::from_qtensor(q, eps)?;
+    Ok(Norm::Rms(rms))
+}
+
+fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<Norm> {
+    let w = w.dequantize(&w.device())?;
+    let b = b.dequantize(&b.device())?;
+    let ln = candle_nn::LayerNorm::new(w, b, eps);
+    Ok(Norm::Layer(ln))
+}
+
 #[derive(Debug, Clone)]
 struct LayerWeights {
     attention_wq: QMatMul,
     attention_wk: QMatMul,
     attention_wv: QMatMul,
     attention_wo: QMatMul,
-    attention_norm: RmsNorm,
-    mlp_or_moe: MlpOrMoe,
-    ffn_norm: RmsNorm,
+    attention_norm: Norm,
+    mlp: Mlp,
+    ffn_norm: Norm,
     n_head: usize,
     n_kv_head: usize,
     head_dim: usize,
|
|||||||
pub struct ModelWeights {
|
pub struct ModelWeights {
|
||||||
tok_embeddings: Embedding,
|
tok_embeddings: Embedding,
|
||||||
layers: Vec<LayerWeights>,
|
layers: Vec<LayerWeights>,
|
||||||
norm: RmsNorm,
|
norm: Norm,
|
||||||
output: QMatMul,
|
output: QMatMul,
|
||||||
masks: HashMap<usize, Tensor>,
|
masks: HashMap<usize, Tensor>,
|
||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
@@ -256,6 +298,99 @@ fn precomput_freqs_cis(
     Ok((cos, sin))
 }
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum Architecture {
+    Llama,
+    Phi2,
+}
+
+#[derive(Debug, Clone)]
+struct MetadataConfig {
+    n_expert: usize,
+    n_expert_used: usize,
+    head_count: usize,
+    head_count_kv: usize,
+    block_count: usize,
+    embedding_length: usize,
+    rope_dim: usize,
+    rms_norm_eps: f64,
+    rope_freq_base: f32,
+    architecture: Architecture,
+}
+
+impl MetadataConfig {
+    fn from_gguf(ct: &gguf_file::Content) -> Result<Self> {
+        let md_get = |s: &str| match ct.metadata.get(s) {
+            None => candle::bail!("cannot find {s} in metadata"),
+            Some(v) => Ok(v),
+        };
+
+        let architecture = match md_get("general.architecture")
+            .and_then(|v| v.to_string())
+            .map(|v| v.as_str())
+        {
+            Ok("phi2") => Architecture::Phi2,
+            Err(_) | Ok(_) => Architecture::Llama,
+        };
+
+        let config = match architecture {
+            Architecture::Phi2 => {
+                let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize;
+                let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize;
+                let block_count = md_get("phi2.block_count")?.to_u32()? as usize;
+                let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize;
+                let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize;
+                let rms_norm_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64;
+                Self {
+                    n_expert: 1,
+                    n_expert_used: 1,
+                    head_count,
+                    head_count_kv,
+                    block_count,
+                    embedding_length,
+                    rope_freq_base: 10_000.,
+                    rope_dim,
+                    rms_norm_eps,
+                    architecture,
+                }
+            }
+            Architecture::Llama => {
+                let n_expert = md_get("llama.expert_count")
+                    .and_then(|v| v.to_u32())
+                    .unwrap_or(0) as usize;
+                let n_expert_used = md_get("llama.expert_used_count")
+                    .and_then(|v| v.to_u32())
+                    .unwrap_or(0) as usize;
+                let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
+                let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
+                let block_count = md_get("llama.block_count")?.to_u32()? as usize;
+                let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
+                let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
+                // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
+                let rms_norm_eps =
+                    md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
+
+                let rope_freq_base = md_get("llama.rope.freq_base")
+                    .and_then(|m| m.to_f32())
+                    .unwrap_or(10000f32);
+                Self {
+                    n_expert,
+                    n_expert_used,
+                    head_count,
+                    head_count_kv,
+                    block_count,
+                    embedding_length,
+                    rope_freq_base,
+                    rope_dim,
+                    rms_norm_eps,
+                    architecture,
+                }
+            }
+        };
+        Ok(config)
+    }
+}
+
 impl ModelWeights {
     pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
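`MetadataConfig::from_gguf` centralizes the metadata parsing that used to live inline in `from_gguf`: it reads `general.architecture` and then pulls either the `phi2.*` or the `llama.*` keys, defaulting to the llama layout when the field is absent. A quick way to see which keys a given GGUF file actually carries, assuming the `candle::quantized::gguf_file` API used above (the file path is illustrative):

```rust
use candle::quantized::gguf_file;

fn main() -> anyhow::Result<()> {
    // Dump every metadata entry, e.g. `general.architecture`, `phi2.block_count`, ...
    let mut file = std::fs::File::open("phi-2.Q4_K_M.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;
    for (key, value) in content.metadata.iter() {
        println!("{key}: {value:?}");
    }
    Ok(())
}
```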
@@ -263,7 +398,7 @@ impl ModelWeights {
        let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
-        let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?;
+        let norm = rms_norm(ct.remove("norm.weight")?, 1e-5)?;
        let output = ct.remove("output.weight")?;
        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
        for layer_idx in 0..ct.hparams.n_layer {
@@ -272,11 +407,11 @@ impl ModelWeights {
            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
-            let mlp_or_moe = {
+            let mlp = {
                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
-                MlpOrMoe::Mlp(Mlp {
+                Mlp::Silu(MlpSilu {
                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
@@ -292,9 +427,9 @@ impl ModelWeights {
                attention_wk: QMatMul::from_qtensor(attention_wk)?,
                attention_wv: QMatMul::from_qtensor(attention_wv)?,
                attention_wo: QMatMul::from_qtensor(attention_wo)?,
-                attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
-                mlp_or_moe,
-                ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
+                attention_norm: rms_norm(attention_norm, 1e-5)?,
+                mlp,
+                ffn_norm: rms_norm(ffn_norm, 1e-5)?,
                n_head: ct.hparams.n_head as usize,
                n_kv_head: ct.hparams.n_head as usize / gqa,
                head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
@@ -325,78 +460,71 @@ impl ModelWeights {
        reader: &mut R,
        device: &Device,
    ) -> Result<Self> {
-        let md_get = |s: &str| match ct.metadata.get(s) {
-            None => candle::bail!("cannot find {s} in metadata"),
-            Some(v) => Ok(v),
-        };
+        let cfg = MetadataConfig::from_gguf(&ct)?;
 
-        // Parameter extraction from metadata.
-        let n_expert = md_get("llama.expert_count")
-            .and_then(|v| v.to_u32())
-            .unwrap_or(0) as usize;
-        let n_expert_used = md_get("llama.expert_used_count")
-            .and_then(|v| v.to_u32())
-            .unwrap_or(0) as usize;
-        let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
-        let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
-        let block_count = md_get("llama.block_count")?.to_u32()? as usize;
-        let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
-        let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
-        // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
-        let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
-
-        let rope_freq_base = md_get("llama.rope.freq_base")
-            .and_then(|m| m.to_f32())
-            .unwrap_or(10000f32);
-        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
+        let (cos, sin) = precomput_freqs_cis(cfg.rope_dim, cfg.rope_freq_base, device)?;
        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
-        let norm = RmsNorm::from_qtensor(
+        let norm = rms_norm(
            ct.tensor(reader, "output_norm.weight", device)?,
-            rms_norm_eps,
+            cfg.rms_norm_eps,
        )?;
        let output = ct.tensor(reader, "output.weight", device)?;
-        let mut layers = Vec::with_capacity(block_count);
-        for layer_idx in 0..block_count {
+        let mut layers = Vec::with_capacity(cfg.block_count);
+        for layer_idx in 0..cfg.block_count {
            let prefix = format!("blk.{layer_idx}");
            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo =
                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
-            let mlp_or_moe = if n_expert <= 1 {
-                let feed_forward_w1 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
-                let feed_forward_w2 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
-                let feed_forward_w3 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
-                MlpOrMoe::Mlp(Mlp {
-                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
-                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
-                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
-                })
+            let mlp = if cfg.n_expert <= 1 {
+                match cfg.architecture {
+                    Architecture::Llama => {
+                        let feed_forward_w1 =
+                            ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
+                        let feed_forward_w2 =
+                            ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+                        let feed_forward_w3 =
+                            ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
+                        Mlp::Silu(MlpSilu {
+                            feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+                            feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+                            feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+                        })
+                    }
+                    Architecture::Phi2 => {
+                        let fc1 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
+                        let fc2 =
+                            ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+                        Mlp::Simple(MlpSimple {
+                            fc1: QMatMul::from_qtensor(fc1)?,
+                            fc2: QMatMul::from_qtensor(fc2)?,
+                            act: candle_nn::Activation::NewGelu,
+                        })
+                    }
+                }
            } else {
                let feed_forward_gate_inp =
                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
-                let mut experts = Vec::with_capacity(n_expert);
-                for i in 0..n_expert {
+                let mut experts = Vec::with_capacity(cfg.n_expert);
+                for i in 0..cfg.n_expert {
                    let feed_forward_w1 =
                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                    let feed_forward_w2 =
                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
                    let feed_forward_w3 =
                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
-                    experts.push(Mlp {
+                    experts.push(MlpSilu {
                        feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                        feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                        feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                    })
                }
-                MlpOrMoe::MoE {
-                    n_expert_used,
+                Mlp::MoE {
+                    n_expert_used: cfg.n_expert_used,
                    feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
                    experts,
                }
@@ -412,12 +540,12 @@ impl ModelWeights {
                attention_wk: QMatMul::from_qtensor(attention_wk)?,
                attention_wv: QMatMul::from_qtensor(attention_wv)?,
                attention_wo: QMatMul::from_qtensor(attention_wo)?,
-                attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
-                mlp_or_moe,
-                ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
-                n_head: head_count,
-                n_kv_head: head_count_kv,
-                head_dim: embedding_length / head_count,
+                attention_norm: rms_norm(attention_norm, cfg.rms_norm_eps)?,
+                mlp,
+                ffn_norm: rms_norm(ffn_norm, cfg.rms_norm_eps)?,
+                n_head: cfg.head_count,
+                n_kv_head: cfg.head_count_kv,
+                head_dim: cfg.embedding_length / cfg.head_count,
                cos: cos.clone(),
                sin: sin.clone(),
                neg_inf: neg_inf.clone(),
@@ -430,7 +558,7 @@ impl ModelWeights {
        let span = tracing::span!(tracing::Level::TRACE, "model");
        let span_output = tracing::span!(tracing::Level::TRACE, "output");
        Ok(Self {
-            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
+            tok_embeddings: Embedding::new(tok_embeddings, cfg.embedding_length),
            layers,
            norm,
            output: QMatMul::from_qtensor(output)?,
@@ -473,7 +601,7 @@ impl ModelWeights {
            let _enter = layer.span_mlp.enter();
            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
-            let x = layer.mlp_or_moe.forward(&x)?;
+            let x = layer.mlp.forward(&x)?;
            let x = (x + residual)?;
            layer_in = x
        }
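With these changes a single `ModelWeights::from_gguf` entry point covers both llama-style and phi-2 GGUF files, choosing the MLP and normalization variants from the file's metadata. A hedged end-to-end loading sketch; the file path, device choice, and the `candle_transformers::models::quantized_llama` module path are assumptions rather than part of this diff:

```rust
use candle::Device;
use candle::quantized::gguf_file;
use candle_transformers::models::quantized_llama::ModelWeights;

fn main() -> anyhow::Result<()> {
    // Path is illustrative; any GGUF file with a `general.architecture` entry
    // of "phi2" or a llama-style layout should work under the new code paths.
    let path = "phi-2.Q4_K_M.gguf";
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    // from_gguf inspects the metadata and builds either the SiLU/MoE layers
    // (llama-style) or the phi-2 style layers.
    let _model = ModelWeights::from_gguf(content, &mut file, &Device::Cpu)?;
    println!("loaded {path}");
    Ok(())
}
```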