mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
GGUF support in the quantized model. (#559)
* GGUF support in the quantized model. * Get the GGUF support to work on llama.
This commit is contained in:
@ -10,8 +10,8 @@ use std::collections::HashMap;
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use candle::quantized::ggml_file::Content;
|
||||
use candle::quantized::QTensor;
|
||||
use candle::quantized::{ggml_file, gguf_file};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, Module};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
@ -195,24 +195,26 @@ impl WeightMap {
|
||||
}
|
||||
}
|
||||
|
||||
fn precomput_freqs_cis(head_dim: usize) -> Result<(Tensor, Tensor)> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
impl ModelWeights {
|
||||
fn new(mut ct: Content, gqa: usize) -> Result<Self> {
|
||||
fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
||||
|
||||
// precompute freqs_cis
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
|
||||
let (cos, sin) = precomput_freqs_cis(head_dim)?;
|
||||
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let norm = RmsNorm::new(ct.remove("norm.weight")?)?;
|
||||
@ -220,7 +222,7 @@ impl ModelWeights {
|
||||
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
|
||||
for layer_idx in 0..ct.hparams.n_layer {
|
||||
let prefix = format!("layers.{layer_idx}");
|
||||
let attention_wq = ct.remove(&format!("layers.{layer_idx}.attention.wq.weight"))?;
|
||||
let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
|
||||
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
|
||||
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
|
||||
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
|
||||
@ -266,6 +268,78 @@ impl ModelWeights {
|
||||
})
|
||||
}
|
||||
|
||||
fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
};
|
||||
|
||||
// Parameter extraction from metadata.
|
||||
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
|
||||
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
|
||||
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?)?;
|
||||
let output = ct.tensor(reader, "output.weight")?;
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
|
||||
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
|
||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
|
||||
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
|
||||
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
|
||||
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
|
||||
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
|
||||
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
|
||||
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||
layers.push(LayerWeights {
|
||||
attention_wq: QMatMul::from_qtensor(attention_wq),
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk),
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv),
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo),
|
||||
attention_norm: RmsNorm::new(attention_norm)?,
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
|
||||
ffn_norm: RmsNorm::new(ffn_norm)?,
|
||||
n_head: head_count,
|
||||
n_kv_head: head_count_kv,
|
||||
head_dim: embedding_length / head_count,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
kv_cache: None,
|
||||
span_attn,
|
||||
span_rot,
|
||||
span_mlp,
|
||||
})
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "model");
|
||||
let span_output = tracing::span!(tracing::Level::TRACE, "output");
|
||||
Ok(Self {
|
||||
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
|
||||
layers,
|
||||
norm,
|
||||
output: QMatMul::from_qtensor(output),
|
||||
masks: HashMap::new(),
|
||||
span,
|
||||
span_output,
|
||||
})
|
||||
}
|
||||
|
||||
fn mask(&mut self, t: usize) -> Result<Tensor> {
|
||||
if let Some(mask) = self.masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
@ -443,6 +517,18 @@ fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Resul
|
||||
Tensor::from_vec(logits, logits_len, &Device::Cpu)
|
||||
}
|
||||
|
||||
fn format_size(size_in_bytes: usize) -> String {
|
||||
if size_in_bytes < 1_000 {
|
||||
format!("{}B", size_in_bytes)
|
||||
} else if size_in_bytes < 1_000_000 {
|
||||
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
|
||||
} else if size_in_bytes < 1_000_000_000 {
|
||||
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
|
||||
} else {
|
||||
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
@ -464,37 +550,49 @@ fn main() -> anyhow::Result<()> {
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
|
||||
let mut file = std::fs::File::open(&args.model()?)?;
|
||||
let model_path = args.model()?;
|
||||
let mut file = std::fs::File::open(&model_path)?;
|
||||
let start = std::time::Instant::now();
|
||||
let model = Content::read(&mut file)?;
|
||||
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensors.iter() {
|
||||
let elem_count = tensor.shape().elem_count();
|
||||
total_size_in_bytes += elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
|
||||
}
|
||||
let total_size = if total_size_in_bytes < 1_000 {
|
||||
format!("{}B", total_size_in_bytes)
|
||||
} else if total_size_in_bytes < 1_000_000 {
|
||||
format!("{:.2}KB", total_size_in_bytes as f64 / 1e3)
|
||||
} else if total_size_in_bytes < 1_000_000_000 {
|
||||
format!("{:.2}MB", total_size_in_bytes as f64 / 1e6)
|
||||
} else {
|
||||
format!("{:.2}GB", total_size_in_bytes as f64 / 1e9)
|
||||
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
|
||||
Some("gguf") => {
|
||||
let model = gguf_file::Content::read(&mut file)?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensor_infos.iter() {
|
||||
let elem_count = tensor.shape.elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
model.tensor_infos.len(),
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
ModelWeights::from_gguf(model, &mut file)?
|
||||
}
|
||||
Some("ggml" | "bin") | Some(_) | None => {
|
||||
let model = ggml_file::Content::read(&mut file)?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensors.iter() {
|
||||
let elem_count = tensor.shape().elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
model.tensors.len(),
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
println!("params: {:?}", model.hparams);
|
||||
let default_gqa = match args.which {
|
||||
Which::L7b | Which::L13b => 1,
|
||||
Which::L70b => 8,
|
||||
};
|
||||
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
||||
}
|
||||
};
|
||||
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
model.tensors.len(),
|
||||
total_size,
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
println!("params: {:?}", model.hparams);
|
||||
let default_gqa = match args.which {
|
||||
Which::L7b | Which::L13b => 1,
|
||||
Which::L70b => 8,
|
||||
};
|
||||
let mut model = ModelWeights::new(model, args.gqa.unwrap_or(default_gqa))?;
|
||||
println!("model built");
|
||||
|
||||
let tokenizer = args.tokenizer()?;
|
||||
|
Reference in New Issue
Block a user