GGUF support in the quantized model. (#559)

* GGUF support in the quantized model.

* Get the GGUF support to work on llama.
This commit is contained in:
Laurent Mazare
2023-08-23 09:20:57 +01:00
committed by GitHub
parent 0764741cc4
commit 508d34daf2
2 changed files with 231 additions and 47 deletions

View File

@ -2,7 +2,7 @@
//! //!
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::GgmlDType; use super::{GgmlDType, QTensor};
use crate::Result; use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt}; use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap; use std::collections::HashMap;
@ -55,7 +55,7 @@ impl TensorInfo {
&self, &self,
reader: &mut R, reader: &mut R,
tensor_data_offset: u64, tensor_data_offset: u64,
) -> Result<super::QTensor> { ) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count(); let tensor_elems = self.shape.elem_count();
let size_in_bytes = let size_in_bytes =
tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size(); tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size();
@ -78,6 +78,10 @@ fn read_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
let len = reader.read_u32::<LittleEndian>()?; let len = reader.read_u32::<LittleEndian>()?;
let mut v = vec![0u8; len as usize]; let mut v = vec![0u8; len as usize];
reader.read_exact(&mut v)?; reader.read_exact(&mut v)?;
// GGUF strings are supposed to be non-null terminated but in practice this happens.
while let Some(0) = v.last() {
v.pop();
}
// GGUF strings are utf8 encoded but there are cases that don't seem to be valid. // GGUF strings are utf8 encoded but there are cases that don't seem to be valid.
Ok(String::from_utf8_lossy(&v).into_owned()) Ok(String::from_utf8_lossy(&v).into_owned())
} }
@ -125,6 +129,76 @@ pub enum Value {
} }
impl Value { impl Value {
pub fn to_u8(&self) -> Result<u8> {
match self {
Self::U8(v) => Ok(*v),
v => crate::bail!("not a u8 {v:?}"),
}
}
pub fn to_i8(&self) -> Result<i8> {
match self {
Self::I8(v) => Ok(*v),
v => crate::bail!("not a i8 {v:?}"),
}
}
pub fn to_u16(&self) -> Result<u16> {
match self {
Self::U16(v) => Ok(*v),
v => crate::bail!("not a u16 {v:?}"),
}
}
pub fn to_i16(&self) -> Result<i16> {
match self {
Self::I16(v) => Ok(*v),
v => crate::bail!("not a i16 {v:?}"),
}
}
pub fn to_u32(&self) -> Result<u32> {
match self {
Self::U32(v) => Ok(*v),
v => crate::bail!("not a u32 {v:?}"),
}
}
pub fn to_i32(&self) -> Result<i32> {
match self {
Self::I32(v) => Ok(*v),
v => crate::bail!("not a i32 {v:?}"),
}
}
pub fn to_f32(&self) -> Result<f32> {
match self {
Self::F32(v) => Ok(*v),
v => crate::bail!("not a f32 {v:?}"),
}
}
pub fn to_bool(&self) -> Result<bool> {
match self {
Self::Bool(v) => Ok(*v),
v => crate::bail!("not a bool {v:?}"),
}
}
pub fn to_vec(&self) -> Result<&Vec<Value>> {
match self {
Self::Array(v) => Ok(v),
v => crate::bail!("not a vec {v:?}"),
}
}
pub fn to_string(&self) -> Result<&String> {
match self {
Self::String(v) => Ok(v),
v => crate::bail!("not a string {v:?}"),
}
}
fn read<R: std::io::Read>(reader: &mut R, value_type: ValueType) -> Result<Self> { fn read<R: std::io::Read>(reader: &mut R, value_type: ValueType) -> Result<Self> {
let v = match value_type { let v = match value_type {
ValueType::U8 => Self::U8(reader.read_u8()?), ValueType::U8 => Self::U8(reader.read_u8()?),
@ -225,4 +299,16 @@ impl Content {
tensor_data_offset, tensor_data_offset,
}) })
} }
pub fn tensor<R: std::io::Seek + std::io::Read>(
&self,
reader: &mut R,
name: &str,
) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor-infor for {name}"),
};
tensor_info.read(reader, self.tensor_data_offset)
}
} }

View File

@ -10,8 +10,8 @@ use std::collections::HashMap;
use std::io::Write; use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use candle::quantized::ggml_file::Content;
use candle::quantized::QTensor; use candle::quantized::QTensor;
use candle::quantized::{ggml_file, gguf_file};
use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module}; use candle_nn::{Embedding, Module};
use candle_transformers::generation::LogitsProcessor; use candle_transformers::generation::LogitsProcessor;
@ -195,24 +195,26 @@ impl WeightMap {
} }
} }
fn precomput_freqs_cis(head_dim: usize) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok((cos, sin))
}
impl ModelWeights { impl ModelWeights {
fn new(mut ct: Content, gqa: usize) -> Result<Self> { fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
let cpu = &Device::Cpu; let cpu = &Device::Cpu;
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
let (cos, sin) = precomput_freqs_cis(head_dim)?;
// precompute freqs_cis
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?; let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?; let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let norm = RmsNorm::new(ct.remove("norm.weight")?)?; let norm = RmsNorm::new(ct.remove("norm.weight")?)?;
@ -220,7 +222,7 @@ impl ModelWeights {
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
for layer_idx in 0..ct.hparams.n_layer { for layer_idx in 0..ct.hparams.n_layer {
let prefix = format!("layers.{layer_idx}"); let prefix = format!("layers.{layer_idx}");
let attention_wq = ct.remove(&format!("layers.{layer_idx}.attention.wq.weight"))?; let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
@ -266,6 +268,78 @@ impl ModelWeights {
}) })
} }
fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
) -> Result<Self> {
let cpu = &Device::Cpu;
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
// Parameter extraction from metadata.
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
let (cos, sin) = precomput_freqs_cis(rope_dim)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?)?;
let output = ct.tensor(reader, "output.weight")?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq),
attention_wk: QMatMul::from_qtensor(attention_wk),
attention_wv: QMatMul::from_qtensor(attention_wv),
attention_wo: QMatMul::from_qtensor(attention_wo),
attention_norm: RmsNorm::new(attention_norm)?,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
ffn_norm: RmsNorm::new(ffn_norm)?,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: embedding_length / head_count,
cos: cos.clone(),
sin: sin.clone(),
kv_cache: None,
span_attn,
span_rot,
span_mlp,
})
}
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
layers,
norm,
output: QMatMul::from_qtensor(output),
masks: HashMap::new(),
span,
span_output,
})
}
fn mask(&mut self, t: usize) -> Result<Tensor> { fn mask(&mut self, t: usize) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) { if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone()) Ok(mask.clone())
@ -443,6 +517,18 @@ fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Resul
Tensor::from_vec(logits, logits_len, &Device::Cpu) Tensor::from_vec(logits, logits_len, &Device::Cpu)
} }
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
} else if size_in_bytes < 1_000_000 {
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
} else if size_in_bytes < 1_000_000_000 {
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
} else {
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
}
}
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder; use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
@ -464,37 +550,49 @@ fn main() -> anyhow::Result<()> {
candle::utils::with_f16c() candle::utils::with_f16c()
); );
let mut file = std::fs::File::open(&args.model()?)?; let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let model = Content::read(&mut file)?;
let mut total_size_in_bytes = 0; let mut model = match model_path.extension().and_then(|v| v.to_str()) {
for (_, tensor) in model.tensors.iter() { Some("gguf") => {
let elem_count = tensor.shape().elem_count(); let model = gguf_file::Content::read(&mut file)?;
total_size_in_bytes += elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size(); let mut total_size_in_bytes = 0;
} for (_, tensor) in model.tensor_infos.iter() {
let total_size = if total_size_in_bytes < 1_000 { let elem_count = tensor.shape.elem_count();
format!("{}B", total_size_in_bytes) total_size_in_bytes +=
} else if total_size_in_bytes < 1_000_000 { elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
format!("{:.2}KB", total_size_in_bytes as f64 / 1e3) }
} else if total_size_in_bytes < 1_000_000_000 { println!(
format!("{:.2}MB", total_size_in_bytes as f64 / 1e6) "loaded {:?} tensors ({}) in {:.2}s",
} else { model.tensor_infos.len(),
format!("{:.2}GB", total_size_in_bytes as f64 / 1e9) &format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
ModelWeights::from_gguf(model, &mut file)?
}
Some("ggml" | "bin") | Some(_) | None => {
let model = ggml_file::Content::read(&mut file)?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count();
total_size_in_bytes +=
elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensors.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
println!("params: {:?}", model.hparams);
let default_gqa = match args.which {
Which::L7b | Which::L13b => 1,
Which::L70b => 8,
};
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
}
}; };
println!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensors.len(),
total_size,
start.elapsed().as_secs_f32(),
);
println!("params: {:?}", model.hparams);
let default_gqa = match args.which {
Which::L7b | Which::L13b => 1,
Which::L70b => 8,
};
let mut model = ModelWeights::new(model, args.gqa.unwrap_or(default_gqa))?;
println!("model built"); println!("model built");
let tokenizer = args.tokenizer()?; let tokenizer = args.tokenizer()?;