Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
GGUF support in the quantized model. (#559)

* GGUF support in the quantized model.
* Get the GGUF support to work on llama.
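For orientation, here is a minimal sketch (not part of the commit) of the loading path this change enables. It assumes it runs inside the quantized llama example where `ModelWeights` is defined, and `model.gguf` is a placeholder path; the calls themselves (`gguf_file::Content::read`, `ModelWeights::from_gguf`) are the ones introduced in the hunks below.

use candle::quantized::gguf_file;

fn load_gguf_model(path: &str) -> anyhow::Result<ModelWeights> {
    let mut file = std::fs::File::open(path)?;
    // Parse the GGUF header: metadata key/values plus tensor infos and the data offset.
    let content = gguf_file::Content::read(&mut file)?;
    // Build the quantized llama weights, reading tensor data back from the same file.
    let model = ModelWeights::from_gguf(content, &mut file)?;
    Ok(model)
}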
@@ -2,7 +2,7 @@
 //!
 //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md

-use super::GgmlDType;
+use super::{GgmlDType, QTensor};
 use crate::Result;
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::collections::HashMap;
@@ -55,7 +55,7 @@ impl TensorInfo {
         &self,
         reader: &mut R,
         tensor_data_offset: u64,
-    ) -> Result<super::QTensor> {
+    ) -> Result<QTensor> {
         let tensor_elems = self.shape.elem_count();
         let size_in_bytes =
             tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size();
@@ -78,6 +78,10 @@ fn read_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
     let len = reader.read_u32::<LittleEndian>()?;
     let mut v = vec![0u8; len as usize];
     reader.read_exact(&mut v)?;
+    // GGUF strings are not supposed to be null-terminated, but in practice some are; strip trailing zeros.
+    while let Some(0) = v.last() {
+        v.pop();
+    }
     // GGUF strings are utf8 encoded but there are cases that don't seem to be valid.
     Ok(String::from_utf8_lossy(&v).into_owned())
 }
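A quick illustration of what the added loop tolerates, with hypothetical bytes that are not part of the commit: a length-prefixed string whose payload ends in a stray NUL now decodes to the intended text. Sketched as a test in the same module:

#[test]
fn read_string_strips_trailing_nul() -> Result<()> {
    // "llama\0" preceded by its little-endian u32 length (6); the loop above drops the trailing 0.
    let bytes: &[u8] = &[6, 0, 0, 0, b'l', b'l', b'a', b'm', b'a', 0];
    let mut cursor = std::io::Cursor::new(bytes);
    assert_eq!(read_string(&mut cursor)?, "llama");
    Ok(())
}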
@@ -125,6 +129,76 @@ pub enum Value {
 }

 impl Value {
+    pub fn to_u8(&self) -> Result<u8> {
+        match self {
+            Self::U8(v) => Ok(*v),
+            v => crate::bail!("not a u8 {v:?}"),
+        }
+    }
+
+    pub fn to_i8(&self) -> Result<i8> {
+        match self {
+            Self::I8(v) => Ok(*v),
+            v => crate::bail!("not a i8 {v:?}"),
+        }
+    }
+
+    pub fn to_u16(&self) -> Result<u16> {
+        match self {
+            Self::U16(v) => Ok(*v),
+            v => crate::bail!("not a u16 {v:?}"),
+        }
+    }
+
+    pub fn to_i16(&self) -> Result<i16> {
+        match self {
+            Self::I16(v) => Ok(*v),
+            v => crate::bail!("not a i16 {v:?}"),
+        }
+    }
+
+    pub fn to_u32(&self) -> Result<u32> {
+        match self {
+            Self::U32(v) => Ok(*v),
+            v => crate::bail!("not a u32 {v:?}"),
+        }
+    }
+
+    pub fn to_i32(&self) -> Result<i32> {
+        match self {
+            Self::I32(v) => Ok(*v),
+            v => crate::bail!("not a i32 {v:?}"),
+        }
+    }
+
+    pub fn to_f32(&self) -> Result<f32> {
+        match self {
+            Self::F32(v) => Ok(*v),
+            v => crate::bail!("not a f32 {v:?}"),
+        }
+    }
+
+    pub fn to_bool(&self) -> Result<bool> {
+        match self {
+            Self::Bool(v) => Ok(*v),
+            v => crate::bail!("not a bool {v:?}"),
+        }
+    }
+
+    pub fn to_vec(&self) -> Result<&Vec<Value>> {
+        match self {
+            Self::Array(v) => Ok(v),
+            v => crate::bail!("not a vec {v:?}"),
+        }
+    }
+
+    pub fn to_string(&self) -> Result<&String> {
+        match self {
+            Self::String(v) => Ok(v),
+            v => crate::bail!("not a string {v:?}"),
+        }
+    }
+
     fn read<R: std::io::Read>(reader: &mut R, value_type: ValueType) -> Result<Self> {
         let v = match value_type {
             ValueType::U8 => Self::U8(reader.read_u8()?),
@@ -225,4 +299,16 @@ impl Content {
             tensor_data_offset,
         })
     }
+
+    pub fn tensor<R: std::io::Seek + std::io::Read>(
+        &self,
+        reader: &mut R,
+        name: &str,
+    ) -> Result<QTensor> {
+        let tensor_info = match self.tensor_infos.get(name) {
+            Some(tensor_info) => tensor_info,
+            None => crate::bail!("cannot find tensor info for {name}"),
+        };
+        tensor_info.read(reader, self.tensor_data_offset)
+    }
 }
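Taken together, the `to_*` accessors and `Content::tensor` give callers a small typed view over GGUF metadata and tensor data. A minimal sketch of how they combine, with `model.gguf` and the key/tensor names used only as examples of the llama-style files this commit targets:

use candle::quantized::gguf_file;
use candle::{Device, Result};

fn inspect_gguf(path: &str) -> Result<()> {
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    // Metadata values come back as `Value`; the new helpers convert them to plain types.
    let block_count = match content.metadata.get("llama.block_count") {
        Some(v) => v.to_u32()? as usize,
        None => candle::bail!("missing llama.block_count"),
    };
    // Tensor data is read lazily from the same reader via the stored tensor infos.
    let wq = content.tensor(&mut file, "blk.0.attn_q.weight")?;
    let wq = wq.dequantize(&Device::Cpu)?; // QTensor -> Tensor
    println!("{block_count} blocks, blk.0.attn_q.weight: {:?}", wq.shape());
    Ok(())
}

The remaining hunks are in the quantized llama example, which gains a GGUF-aware constructor and file-extension dispatch in `main()`.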
@@ -10,8 +10,8 @@ use std::collections::HashMap;
 use std::io::Write;
 use tokenizers::Tokenizer;

-use candle::quantized::ggml_file::Content;
 use candle::quantized::QTensor;
+use candle::quantized::{ggml_file, gguf_file};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module};
 use candle_transformers::generation::LogitsProcessor;
@@ -195,24 +195,26 @@ impl WeightMap {
     }
 }

+fn precomput_freqs_cis(head_dim: usize) -> Result<(Tensor, Tensor)> {
+    let theta: Vec<_> = (0..head_dim)
+        .step_by(2)
+        .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
+        .collect();
+    let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
+    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
+        .to_dtype(DType::F32)?
+        .reshape((MAX_SEQ_LEN, 1))?
+        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+    let cos = idx_theta.cos()?;
+    let sin = idx_theta.sin()?;
+    Ok((cos, sin))
+}
+
 impl ModelWeights {
-    fn new(mut ct: Content, gqa: usize) -> Result<Self> {
+    fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
         let cpu = &Device::Cpu;
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
-        // precompute freqs_cis
-        let theta: Vec<_> = (0..head_dim)
-            .step_by(2)
-            .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
-            .collect();
-        let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
-        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
-            .to_dtype(DType::F32)?
-            .reshape((MAX_SEQ_LEN, 1))?
-            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-        let cos = idx_theta.cos()?;
-        let sin = idx_theta.sin()?;
-
+        let (cos, sin) = precomput_freqs_cis(head_dim)?;
         let tok_embeddings = ct.remove("tok_embeddings.weight")?;
         let tok_embeddings = tok_embeddings.dequantize(cpu)?;
         let norm = RmsNorm::new(ct.remove("norm.weight")?)?;
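For reference, the factored-out helper builds the usual rotary-embedding angle table. With head dimension $d$ and positions $p = 0, \dots, \mathrm{MAX\_SEQ\_LEN} - 1$, the code above computes

\theta_j = 10000^{-2j/d}, \qquad j = 0, \dots, d/2 - 1

\cos_{p,j} = \cos(p\,\theta_j), \qquad \sin_{p,j} = \sin(p\,\theta_j)

so both returned tensors have shape (MAX_SEQ_LEN, d/2).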
@@ -220,7 +222,7 @@ impl ModelWeights {
         let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
         for layer_idx in 0..ct.hparams.n_layer {
             let prefix = format!("layers.{layer_idx}");
-            let attention_wq = ct.remove(&format!("layers.{layer_idx}.attention.wq.weight"))?;
+            let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
             let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
             let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
             let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
@@ -266,6 +268,78 @@ impl ModelWeights {
         })
     }
+
+    fn from_gguf<R: std::io::Seek + std::io::Read>(
+        ct: gguf_file::Content,
+        reader: &mut R,
+    ) -> Result<Self> {
+        let cpu = &Device::Cpu;
+        let md_get = |s: &str| match ct.metadata.get(s) {
+            None => candle::bail!("cannot find {s} in metadata"),
+            Some(v) => Ok(v),
+        };
+
+        // Parameter extraction from metadata.
+        let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
+        let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
+        let block_count = md_get("llama.block_count")?.to_u32()? as usize;
+        let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
+        let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
+
+        let (cos, sin) = precomput_freqs_cis(rope_dim)?;
+
+        let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
+        let tok_embeddings = tok_embeddings.dequantize(cpu)?;
+        let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?)?;
+        let output = ct.tensor(reader, "output.weight")?;
+        let mut layers = Vec::with_capacity(block_count);
+        for layer_idx in 0..block_count {
+            let prefix = format!("blk.{layer_idx}");
+            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
+            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
+            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
+            let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
+            let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
+            let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
+            let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
+            let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
+            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
+            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
+            let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
+            let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
+            layers.push(LayerWeights {
+                attention_wq: QMatMul::from_qtensor(attention_wq),
+                attention_wk: QMatMul::from_qtensor(attention_wk),
+                attention_wv: QMatMul::from_qtensor(attention_wv),
+                attention_wo: QMatMul::from_qtensor(attention_wo),
+                attention_norm: RmsNorm::new(attention_norm)?,
+                feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
+                feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
+                feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
+                ffn_norm: RmsNorm::new(ffn_norm)?,
+                n_head: head_count,
+                n_kv_head: head_count_kv,
+                head_dim: embedding_length / head_count,
+                cos: cos.clone(),
+                sin: sin.clone(),
+                kv_cache: None,
+                span_attn,
+                span_rot,
+                span_mlp,
+            })
+        }
+        let span = tracing::span!(tracing::Level::TRACE, "model");
+        let span_output = tracing::span!(tracing::Level::TRACE, "output");
+        Ok(Self {
+            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
+            layers,
+            norm,
+            output: QMatMul::from_qtensor(output),
+            masks: HashMap::new(),
+            span,
+            span_output,
+        })
+    }

     fn mask(&mut self, t: usize) -> Result<Tensor> {
         if let Some(mask) = self.masks.get(&t) {
             Ok(mask.clone())
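Compared with `from_ggml` above, the GGUF constructor expects the llama.cpp naming scheme: `token_embd.weight`, `output_norm.weight` and `blk.{i}.attn_q.weight` rather than `tok_embeddings.weight`, `norm.weight` and `layers.{i}.attention.wq.weight`, with the old `hparams` fields (`n_head`, `n_layer`, `n_embd`) replaced by the `llama.*` metadata keys read through `md_get`. A small inspection sketch (not part of the commit) that dumps what a given GGUF file actually provides, handy when one of those keys or tensors comes up missing:

use candle::quantized::gguf_file;
use candle::Result;

fn dump_gguf(path: &str) -> Result<()> {
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    // Metadata key/values first, then the tensor names and shapes.
    for (key, value) in content.metadata.iter() {
        println!("{key}: {value:?}");
    }
    for (name, info) in content.tensor_infos.iter() {
        println!("{name}: {:?}", info.shape);
    }
    Ok(())
}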
@@ -443,6 +517,18 @@ fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Resul
     Tensor::from_vec(logits, logits_len, &Device::Cpu)
 }

+fn format_size(size_in_bytes: usize) -> String {
+    if size_in_bytes < 1_000 {
+        format!("{}B", size_in_bytes)
+    } else if size_in_bytes < 1_000_000 {
+        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
+    } else if size_in_bytes < 1_000_000_000 {
+        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
+    } else {
+        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
+    }
+}
+
 fn main() -> anyhow::Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -464,37 +550,49 @@ fn main() -> anyhow::Result<()> {
         candle::utils::with_f16c()
     );

-    let mut file = std::fs::File::open(&args.model()?)?;
+    let model_path = args.model()?;
+    let mut file = std::fs::File::open(&model_path)?;
     let start = std::time::Instant::now();
-    let model = Content::read(&mut file)?;

-    let mut total_size_in_bytes = 0;
-    for (_, tensor) in model.tensors.iter() {
-        let elem_count = tensor.shape().elem_count();
-        total_size_in_bytes += elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
-    }
-    let total_size = if total_size_in_bytes < 1_000 {
-        format!("{}B", total_size_in_bytes)
-    } else if total_size_in_bytes < 1_000_000 {
-        format!("{:.2}KB", total_size_in_bytes as f64 / 1e3)
-    } else if total_size_in_bytes < 1_000_000_000 {
-        format!("{:.2}MB", total_size_in_bytes as f64 / 1e6)
-    } else {
-        format!("{:.2}GB", total_size_in_bytes as f64 / 1e9)
+    let mut model = match model_path.extension().and_then(|v| v.to_str()) {
+        Some("gguf") => {
+            let model = gguf_file::Content::read(&mut file)?;
+            let mut total_size_in_bytes = 0;
+            for (_, tensor) in model.tensor_infos.iter() {
+                let elem_count = tensor.shape.elem_count();
+                total_size_in_bytes +=
+                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+            }
+            println!(
+                "loaded {:?} tensors ({}) in {:.2}s",
+                model.tensor_infos.len(),
+                &format_size(total_size_in_bytes),
+                start.elapsed().as_secs_f32(),
+            );
+            ModelWeights::from_gguf(model, &mut file)?
+        }
+        Some("ggml" | "bin") | Some(_) | None => {
+            let model = ggml_file::Content::read(&mut file)?;
+            let mut total_size_in_bytes = 0;
+            for (_, tensor) in model.tensors.iter() {
+                let elem_count = tensor.shape().elem_count();
+                total_size_in_bytes +=
+                    elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
+            }
+            println!(
+                "loaded {:?} tensors ({}) in {:.2}s",
+                model.tensors.len(),
+                &format_size(total_size_in_bytes),
+                start.elapsed().as_secs_f32(),
+            );
+            println!("params: {:?}", model.hparams);
+            let default_gqa = match args.which {
+                Which::L7b | Which::L13b => 1,
+                Which::L70b => 8,
+            };
+            ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
+        }
     };

-    println!(
-        "loaded {:?} tensors ({}) in {:.2}s",
-        model.tensors.len(),
-        total_size,
-        start.elapsed().as_secs_f32(),
-    );
-    println!("params: {:?}", model.hparams);
-    let default_gqa = match args.which {
-        Which::L7b | Which::L13b => 1,
-        Which::L70b => 8,
-    };
-    let mut model = ModelWeights::new(model, args.gqa.unwrap_or(default_gqa))?;
     println!("model built");

     let tokenizer = args.tokenizer()?;