From ca449f9ee11b892e026972d114c77a0938e1dc0b Mon Sep 17 00:00:00 2001
From: Laurent Mazare <laurent.mazare@gmail.com>
Date: Tue, 15 Aug 2023 22:45:53 +0100
Subject: [PATCH] Add quantized tensors. (#458)

* Add quantized tensors.

* Implement the debug trait for QTensor.

* Add the QMatMul custom op.
---
 candle-core/src/quantized/ggml_file.rs | 131 +++++--------------------
 candle-core/src/quantized/mod.rs       | 114 ++++++++++++++++++++-
 candle-examples/examples/ggml/main.rs  |   3 +-
 3 files changed, 140 insertions(+), 108 deletions(-)

diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs
index 2824f075..ee23cdde 100644
--- a/candle-core/src/quantized/ggml_file.rs
+++ b/candle-core/src/quantized/ggml_file.rs
@@ -1,7 +1,7 @@
 //! Support for the GGML file format.
 
 use super::{k_quants, GgmlDType};
-use crate::{DType, Device, Result, Tensor};
+use crate::Result;
 use byteorder::{LittleEndian, ReadBytesExt};
 
 // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
@@ -116,121 +116,47 @@ impl Vocab {
     }
 }
 
-fn dequantize_and_create_tensor<T: k_quants::GgmlType>(
+fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     raw_data: &[u8],
-    tensor_elems: usize,
     size_in_bytes: usize,
     dims: Vec<usize>,
-    device: &Device,
-) -> Result<Tensor> {
-    let mut f32_data = vec![0f32; tensor_elems];
+) -> Result<super::QTensor> {
     let raw_data_ptr = raw_data.as_ptr();
     let n_blocks = size_in_bytes / std::mem::size_of::<T>();
-    let raw_data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
-    T::to_float(raw_data, &mut f32_data)?;
-    Tensor::from_vec(f32_data, dims, device)
+    let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
+    Ok(super::QTensor::new(data.to_vec(), dims))
 }
 
 /// Creates a [Tensor] from a raw GGML tensor.
-pub fn tensor_from_ggml(
+pub fn qtensor_from_ggml(
     ggml_dtype: GgmlDType,
     raw_data: &[u8],
     dims: Vec<usize>,
-    dtype: DType,
-    device: &Device,
-) -> Result<Tensor> {
+) -> Result<super::QTensor> {
     let tensor_elems = dims.iter().product::<usize>();
     let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
 
-    let tensor = match ggml_dtype {
-        GgmlDType::F32 => Tensor::from_raw_buffer(raw_data, DType::F32, &dims, device),
-        GgmlDType::F16 => Tensor::from_raw_buffer(raw_data, DType::F16, &dims, device),
-        GgmlDType::Q4_0 => dequantize_and_create_tensor::<k_quants::BlockQ4_0>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q4_1 => dequantize_and_create_tensor::<k_quants::BlockQ4_1>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q5_0 => dequantize_and_create_tensor::<k_quants::BlockQ5_0>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q5_1 => dequantize_and_create_tensor::<k_quants::BlockQ5_1>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q8_0 => dequantize_and_create_tensor::<k_quants::BlockQ8_0>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q2K => dequantize_and_create_tensor::<k_quants::BlockQ2K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q3K => dequantize_and_create_tensor::<k_quants::BlockQ3K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q4K => dequantize_and_create_tensor::<k_quants::BlockQ4K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q5K => dequantize_and_create_tensor::<k_quants::BlockQ5K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q6K => dequantize_and_create_tensor::<k_quants::BlockQ6K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        _ => crate::bail!("quantized type {dtype:?} is not supported yet"),
-    }?;
-    //We only have ggml-quant to f32 conversions, meaning we have to convert to the desired type
-    if tensor.dtype() != dtype {
-        tensor.to_dtype(dtype)
-    } else {
-        Ok(tensor)
+    match ggml_dtype {
+        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
+        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
+        _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
     }
 }
 
 fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     reader: &mut R,
     magic: VersionedMagic,
-    dtype: DType,
-    device: &Device,
-) -> Result<(String, Tensor)> {
+) -> Result<(String, super::QTensor)> {
     let n_dims = reader.read_u32::<LittleEndian>()?;
     let name_len = reader.read_u32::<LittleEndian>()?;
     let ggml_dtype = reader.read_u32::<LittleEndian>()?;
@@ -252,26 +178,21 @@ fn read_one_tensor(
     // TODO: Mmap version to avoid copying the data around?
     let mut raw_data = vec![0u8; size_in_bytes];
     reader.read_exact(&mut raw_data)?;
-    match tensor_from_ggml(ggml_dtype, &raw_data, dims, dtype, device) {
+    match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
         Ok(tensor) => Ok((name, tensor)),
         Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
     }
 }
 
-#[derive(Debug)]
 pub struct Content {
     pub magic: VersionedMagic,
     pub hparams: HParams,
     pub vocab: Vocab,
-    pub tensors: Vec<(String, Tensor)>,
+    pub tensors: Vec<(String, super::QTensor)>,
 }
 
 impl Content {
-    pub fn read<R: std::io::Seek + std::io::Read>(
-        reader: &mut R,
-        dtype: DType,
-        device: &Device,
-    ) -> Result<Content> {
+    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
         // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
         let last_position = reader.seek(std::io::SeekFrom::End(0))?;
         reader.seek(std::io::SeekFrom::Start(0))?;
@@ -281,7 +202,7 @@ impl Content {
 
         let mut tensors = vec![];
         while reader.stream_position()? != last_position {
-            let (name, tensor) = read_one_tensor(reader, magic, dtype, device)?;
+            let (name, tensor) = read_one_tensor(reader, magic)?;
             tensors.push((name, tensor))
         }
         Ok(Self {
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index c7e24592..842b519b 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -1,10 +1,15 @@
-use crate::Result;
+use crate::{Device, Result, Shape, Tensor};
 
 pub mod ggml_file;
 pub mod k_quants;
 
 pub use k_quants::GgmlType;
 
+pub struct QTensor {
+    data: Box<dyn QuantizedType>,
+    shape: Shape,
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum GgmlDType {
     F32,
@@ -80,3 +85,110 @@ impl GgmlDType {
         }
     }
 }
+
+// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
+pub trait QuantizedType: Send + Sync {
+    fn dtype(&self) -> GgmlDType;
+    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
+    fn to_float(&self, ys: &mut [f32]) -> Result<()>;
+}
+
+impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
+    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
+        k_quants::matmul(mkn, lhs, self.as_slice(), dst)
+    }
+
+    fn dtype(&self) -> GgmlDType {
+        T::DTYPE
+    }
+
+    fn to_float(&self, ys: &mut [f32]) -> Result<()> {
+        T::to_float(self.as_slice(), ys)
+    }
+}
+
+impl std::fmt::Debug for QTensor {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
+    }
+}
+
+impl QTensor {
+    pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
+        data: Vec<T>,
+        shape: S,
+    ) -> Self {
+        Self {
+            data: Box::new(data),
+            shape: shape.into(),
+        }
+    }
+
+    pub fn dtype(&self) -> GgmlDType {
+        self.data.dtype()
+    }
+
+    pub fn shape(&self) -> &Shape {
+        &self.shape
+    }
+
+    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
+        let mut f32_data = vec![0f32; self.shape.elem_count()];
+        self.data.to_float(&mut f32_data)?;
+        Tensor::from_vec(f32_data, &self.shape, device)
+    }
+
+    pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
+        self.data.matmul_t(mkn, lhs, dst)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct QMatMul(std::sync::Arc<QTensor>);
+
+impl QMatMul {
+    pub fn new(qtensor: std::sync::Arc<QTensor>) -> Self {
+        Self(qtensor)
+    }
+}
+
+impl crate::CustomOp1 for QMatMul {
+    fn name(&self) -> &'static str {
+        "qmatmul"
+    }
+
+    fn cpu_fwd(
+        &self,
+        storage: &crate::CpuStorage,
+        layout: &crate::Layout,
+    ) -> Result<(crate::CpuStorage, Shape)> {
+        if !layout.is_contiguous() {
+            crate::bail!("input tensor is not contiguous {layout:?}")
+        }
+        let src_shape = layout.shape();
+        let (k, n) = self.0.shape.dims2()?;
+        if src_shape.rank() < 2 {
+            crate::bail!("input tensor has only one dimension {layout:?}")
+        }
+        let mut dst_shape = src_shape.dims().to_vec();
+        let last_k = dst_shape.pop().unwrap();
+        if last_k != k {
+            crate::bail!(
+                "input tensor {layout:?} incompatible with {:?}",
+                self.0.shape
+            )
+        }
+        dst_shape.push(n);
+        let dst_shape = Shape::from(dst_shape);
+        let storage = storage.as_slice::<f32>()?;
+        let storage =
+            &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
+        let mut dst_storage = vec![0f32; dst_shape.elem_count()];
+        self.0.matmul_t(
+            (dst_shape.elem_count() / n, k, n),
+            storage,
+            &mut dst_storage,
+        )?;
+        Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
+    }
+}
diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/ggml/main.rs
index 78eb20dc..9e3e1ba6 100644
--- a/candle-examples/examples/ggml/main.rs
+++ b/candle-examples/examples/ggml/main.rs
@@ -3,7 +3,6 @@ use clap::Parser;
 use std::fs::File;
 
 use candle::quantized::ggml_file::Content;
-use candle::{DType, Device};
 
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
@@ -18,7 +17,7 @@ fn main() -> Result<()> {
     let mut file = File::open(args.model)?;
     let start = std::time::Instant::now();
 
-    let model = Content::read(&mut file, DType::F16, &Device::Cpu)?;
+    let model = Content::read(&mut file)?;
     println!(
         "Loaded {:?} tensors in {:?}",
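
Note (not part of the patch): a minimal usage sketch of the API introduced
above, to show how the pieces fit together. It only calls items defined in
this diff (`Content::read`, the `Debug` impl, `QTensor::dequantize`,
`QTensor::matmul_t`, `QMatMul::new`), plus `anyhow` for error handling as the
example binary already does. The "model.bin" path, the batch size of 4, and
the assumption that the first tensor is a 2D weight are made up for
illustration.

    use anyhow::Result;
    use candle::quantized::{ggml_file::Content, QMatMul};
    use candle::Device;

    fn main() -> Result<()> {
        // Tensors now stay quantized at load time, so `read` no longer takes
        // a target dtype or device.
        let mut file = std::fs::File::open("model.bin")?; // hypothetical path
        let model = Content::read(&mut file)?;
        let (name, qtensor) = model.tensors.into_iter().next().expect("no tensors");
        // The new Debug impl prints the shape and quantized dtype, e.g.
        // `QTensor[[4096, 4096]; Q4_0]`.
        println!("{name}: {qtensor:?}");

        // Dequantize back to a regular f32 Tensor when full precision is needed...
        let full = qtensor.dequantize(&Device::Cpu)?;
        println!("dequantized: {:?}", full.shape());

        // ...or multiply through the still-quantized weight. Following the
        // convention in `QMatMul::cpu_fwd` above, dims2() yields (k, n): the
        // lhs has k columns and the output has n.
        let (k, n) = qtensor.shape().dims2()?; // assumes a 2D weight
        let lhs = vec![0f32; 4 * k]; // a 4 x k row-major f32 input
        let mut dst = vec![0f32; 4 * n];
        qtensor.matmul_t((4, k, n), &lhs, &mut dst)?;

        // QMatMul shares the weight behind an Arc so the same QTensor can run
        // as a `CustomOp1` inside a larger graph.
        let _qmatmul = QMatMul::new(std::sync::Arc::new(qtensor));
        Ok(())
    }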