From 508d34daf291589058929df8d91565cfce9c4375 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 23 Aug 2023 09:20:57 +0100 Subject: [PATCH] GGUF support in the quantized model. (#559) * GGUF support in the quantized model. * Get the GGUF support to work on llama. --- candle-core/src/quantized/gguf_file.rs | 90 +++++++++- candle-examples/examples/quantized/main.rs | 188 ++++++++++++++++----- 2 files changed, 231 insertions(+), 47 deletions(-) diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index 3f13b7de..bf4971c0 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -2,7 +2,7 @@ //! //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md -use super::GgmlDType; +use super::{GgmlDType, QTensor}; use crate::Result; use byteorder::{LittleEndian, ReadBytesExt}; use std::collections::HashMap; @@ -55,7 +55,7 @@ impl TensorInfo { &self, reader: &mut R, tensor_data_offset: u64, - ) -> Result { + ) -> Result { let tensor_elems = self.shape.elem_count(); let size_in_bytes = tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size(); @@ -78,6 +78,10 @@ fn read_string(reader: &mut R) -> Result { let len = reader.read_u32::()?; let mut v = vec![0u8; len as usize]; reader.read_exact(&mut v)?; + // GGUF strings are supposed to be non-null terminated but in practice this happens. + while let Some(0) = v.last() { + v.pop(); + } // GGUF strings are utf8 encoded but there are cases that don't seem to be valid. Ok(String::from_utf8_lossy(&v).into_owned()) } @@ -125,6 +129,76 @@ pub enum Value { } impl Value { + pub fn to_u8(&self) -> Result { + match self { + Self::U8(v) => Ok(*v), + v => crate::bail!("not a u8 {v:?}"), + } + } + + pub fn to_i8(&self) -> Result { + match self { + Self::I8(v) => Ok(*v), + v => crate::bail!("not a i8 {v:?}"), + } + } + + pub fn to_u16(&self) -> Result { + match self { + Self::U16(v) => Ok(*v), + v => crate::bail!("not a u16 {v:?}"), + } + } + + pub fn to_i16(&self) -> Result { + match self { + Self::I16(v) => Ok(*v), + v => crate::bail!("not a i16 {v:?}"), + } + } + + pub fn to_u32(&self) -> Result { + match self { + Self::U32(v) => Ok(*v), + v => crate::bail!("not a u32 {v:?}"), + } + } + + pub fn to_i32(&self) -> Result { + match self { + Self::I32(v) => Ok(*v), + v => crate::bail!("not a i32 {v:?}"), + } + } + + pub fn to_f32(&self) -> Result { + match self { + Self::F32(v) => Ok(*v), + v => crate::bail!("not a f32 {v:?}"), + } + } + + pub fn to_bool(&self) -> Result { + match self { + Self::Bool(v) => Ok(*v), + v => crate::bail!("not a bool {v:?}"), + } + } + + pub fn to_vec(&self) -> Result<&Vec> { + match self { + Self::Array(v) => Ok(v), + v => crate::bail!("not a vec {v:?}"), + } + } + + pub fn to_string(&self) -> Result<&String> { + match self { + Self::String(v) => Ok(v), + v => crate::bail!("not a string {v:?}"), + } + } + fn read(reader: &mut R, value_type: ValueType) -> Result { let v = match value_type { ValueType::U8 => Self::U8(reader.read_u8()?), @@ -225,4 +299,16 @@ impl Content { tensor_data_offset, }) } + + pub fn tensor( + &self, + reader: &mut R, + name: &str, + ) -> Result { + let tensor_info = match self.tensor_infos.get(name) { + Some(tensor_info) => tensor_info, + None => crate::bail!("cannot find tensor-infor for {name}"), + }; + tensor_info.read(reader, self.tensor_data_offset) + } } diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index 477c695f..7c457f7a 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -10,8 +10,8 @@ use std::collections::HashMap; use std::io::Write; use tokenizers::Tokenizer; -use candle::quantized::ggml_file::Content; use candle::quantized::QTensor; +use candle::quantized::{ggml_file, gguf_file}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{Embedding, Module}; use candle_transformers::generation::LogitsProcessor; @@ -195,24 +195,26 @@ impl WeightMap { } } +fn precomput_freqs_cis(head_dim: usize) -> Result<(Tensor, Tensor)> { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? + .matmul(&theta.reshape((1, theta.elem_count()))?)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok((cos, sin)) +} + impl ModelWeights { - fn new(mut ct: Content, gqa: usize) -> Result { + fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result { let cpu = &Device::Cpu; let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; - - // precompute freqs_cis - let theta: Vec<_> = (0..head_dim) - .step_by(2) - .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32)) - .collect(); - let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?; - let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)? - .to_dtype(DType::F32)? - .reshape((MAX_SEQ_LEN, 1))? - .matmul(&theta.reshape((1, theta.elem_count()))?)?; - let cos = idx_theta.cos()?; - let sin = idx_theta.sin()?; - + let (cos, sin) = precomput_freqs_cis(head_dim)?; let tok_embeddings = ct.remove("tok_embeddings.weight")?; let tok_embeddings = tok_embeddings.dequantize(cpu)?; let norm = RmsNorm::new(ct.remove("norm.weight")?)?; @@ -220,7 +222,7 @@ impl ModelWeights { let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); for layer_idx in 0..ct.hparams.n_layer { let prefix = format!("layers.{layer_idx}"); - let attention_wq = ct.remove(&format!("layers.{layer_idx}.attention.wq.weight"))?; + let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?; let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; @@ -266,6 +268,78 @@ impl ModelWeights { }) } + fn from_gguf( + ct: gguf_file::Content, + reader: &mut R, + ) -> Result { + let cpu = &Device::Cpu; + let md_get = |s: &str| match ct.metadata.get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + + // Parameter extraction from metadata. + let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; + let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; + let block_count = md_get("llama.block_count")?.to_u32()? as usize; + let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; + let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; + + let (cos, sin) = precomput_freqs_cis(rope_dim)?; + + let tok_embeddings = ct.tensor(reader, "token_embd.weight")?; + let tok_embeddings = tok_embeddings.dequantize(cpu)?; + let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?)?; + let output = ct.tensor(reader, "output.weight")?; + let mut layers = Vec::with_capacity(block_count); + for layer_idx in 0..block_count { + let prefix = format!("blk.{layer_idx}"); + let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?; + let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; + let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; + let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; + let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; + let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; + let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; + let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; + let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + layers.push(LayerWeights { + attention_wq: QMatMul::from_qtensor(attention_wq), + attention_wk: QMatMul::from_qtensor(attention_wk), + attention_wv: QMatMul::from_qtensor(attention_wv), + attention_wo: QMatMul::from_qtensor(attention_wo), + attention_norm: RmsNorm::new(attention_norm)?, + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), + ffn_norm: RmsNorm::new(ffn_norm)?, + n_head: head_count, + n_kv_head: head_count_kv, + head_dim: embedding_length / head_count, + cos: cos.clone(), + sin: sin.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, embedding_length), + layers, + norm, + output: QMatMul::from_qtensor(output), + masks: HashMap::new(), + span, + span_output, + }) + } + fn mask(&mut self, t: usize) -> Result { if let Some(mask) = self.masks.get(&t) { Ok(mask.clone()) @@ -443,6 +517,18 @@ fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Resul Tensor::from_vec(logits, logits_len, &Device::Cpu) } +fn format_size(size_in_bytes: usize) -> String { + if size_in_bytes < 1_000 { + format!("{}B", size_in_bytes) + } else if size_in_bytes < 1_000_000 { + format!("{:.2}KB", size_in_bytes as f64 / 1e3) + } else if size_in_bytes < 1_000_000_000 { + format!("{:.2}MB", size_in_bytes as f64 / 1e6) + } else { + format!("{:.2}GB", size_in_bytes as f64 / 1e9) + } +} + fn main() -> anyhow::Result<()> { use tracing_chrome::ChromeLayerBuilder; use tracing_subscriber::prelude::*; @@ -464,37 +550,49 @@ fn main() -> anyhow::Result<()> { candle::utils::with_f16c() ); - let mut file = std::fs::File::open(&args.model()?)?; + let model_path = args.model()?; + let mut file = std::fs::File::open(&model_path)?; let start = std::time::Instant::now(); - let model = Content::read(&mut file)?; - let mut total_size_in_bytes = 0; - for (_, tensor) in model.tensors.iter() { - let elem_count = tensor.shape().elem_count(); - total_size_in_bytes += elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size(); - } - let total_size = if total_size_in_bytes < 1_000 { - format!("{}B", total_size_in_bytes) - } else if total_size_in_bytes < 1_000_000 { - format!("{:.2}KB", total_size_in_bytes as f64 / 1e3) - } else if total_size_in_bytes < 1_000_000_000 { - format!("{:.2}MB", total_size_in_bytes as f64 / 1e6) - } else { - format!("{:.2}GB", total_size_in_bytes as f64 / 1e9) + let mut model = match model_path.extension().and_then(|v| v.to_str()) { + Some("gguf") => { + let model = gguf_file::Content::read(&mut file)?; + let mut total_size_in_bytes = 0; + for (_, tensor) in model.tensor_infos.iter() { + let elem_count = tensor.shape.elem_count(); + total_size_in_bytes += + elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size(); + } + println!( + "loaded {:?} tensors ({}) in {:.2}s", + model.tensor_infos.len(), + &format_size(total_size_in_bytes), + start.elapsed().as_secs_f32(), + ); + ModelWeights::from_gguf(model, &mut file)? + } + Some("ggml" | "bin") | Some(_) | None => { + let model = ggml_file::Content::read(&mut file)?; + let mut total_size_in_bytes = 0; + for (_, tensor) in model.tensors.iter() { + let elem_count = tensor.shape().elem_count(); + total_size_in_bytes += + elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size(); + } + println!( + "loaded {:?} tensors ({}) in {:.2}s", + model.tensors.len(), + &format_size(total_size_in_bytes), + start.elapsed().as_secs_f32(), + ); + println!("params: {:?}", model.hparams); + let default_gqa = match args.which { + Which::L7b | Which::L13b => 1, + Which::L70b => 8, + }; + ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))? + } }; - - println!( - "loaded {:?} tensors ({}) in {:.2}s", - model.tensors.len(), - total_size, - start.elapsed().as_secs_f32(), - ); - println!("params: {:?}", model.hparams); - let default_gqa = match args.which { - Which::L7b | Which::L13b => 1, - Which::L70b => 8, - }; - let mut model = ModelWeights::new(model, args.gqa.unwrap_or(default_gqa))?; println!("model built"); let tokenizer = args.tokenizer()?;