diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs
index bf4971c0..17a60b79 100644
--- a/candle-core/src/quantized/gguf_file.rs
+++ b/candle-core/src/quantized/gguf_file.rs
@@ -4,7 +4,7 @@
 
 use super::{GgmlDType, QTensor};
 use crate::Result;
-use byteorder::{LittleEndian, ReadBytesExt};
+use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use std::collections::HashMap;
 
 pub const DEFAULT_ALIGNMENT: u64 = 32;
@@ -19,7 +19,7 @@ impl TryFrom<u32> for Magic {
     fn try_from(value: u32) -> Result<Self> {
         let magic = match value {
             0x46554747 | 0x47475546 => Self::Gguf,
-            _ => crate::bail!("unknown magic {value:08x}"),
+            _ => crate::bail!("unknown magic 0x{value:08x}"),
         };
         Ok(magic)
     }
@@ -86,7 +86,7 @@ fn read_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
     Ok(String::from_utf8_lossy(&v).into_owned())
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum ValueType {
     // The value is a 8-bit unsigned integer.
     U8,
@@ -129,6 +129,21 @@ pub enum Value {
 }
 
 impl Value {
+    pub fn value_type(&self) -> ValueType {
+        match self {
+            Self::U8(_) => ValueType::U8,
+            Self::I8(_) => ValueType::I8,
+            Self::U16(_) => ValueType::U16,
+            Self::I16(_) => ValueType::I16,
+            Self::U32(_) => ValueType::U32,
+            Self::I32(_) => ValueType::I32,
+            Self::F32(_) => ValueType::F32,
+            Self::Bool(_) => ValueType::Bool,
+            Self::String(_) => ValueType::String,
+            Self::Array(_) => ValueType::Array,
+        }
+    }
+
     pub fn to_u8(&self) -> Result<u8> {
         match self {
             Self::U8(v) => Ok(*v),
@@ -227,6 +242,41 @@ impl Value {
         };
         Ok(v)
     }
+
+    fn write<W: std::io::Write>(&self, w: &mut W) -> Result<()> {
+        match self {
+            &Self::U8(v) => w.write_u8(v)?,
+            &Self::I8(v) => w.write_i8(v)?,
+            &Self::U16(v) => w.write_u16::<LittleEndian>(v)?,
+            &Self::I16(v) => w.write_i16::<LittleEndian>(v)?,
+            &Self::U32(v) => w.write_u32::<LittleEndian>(v)?,
+            &Self::I32(v) => w.write_i32::<LittleEndian>(v)?,
+            &Self::F32(v) => w.write_f32::<LittleEndian>(v)?,
+            &Self::Bool(v) => w.write_u8(u8::from(v))?,
+            Self::String(v) => write_string(w, v.as_str())?,
+            Self::Array(v) => {
+                // The `Value` type does not enforce that all the values in an Array have the same
+                // type.
+                let value_type = if v.is_empty() {
+                    // Doesn't matter, the array is empty.
+                    ValueType::U32
+                } else {
+                    let value_type: std::collections::HashSet<_> =
+                        v.iter().map(|elem| elem.value_type()).collect();
+                    if value_type.len() != 1 {
+                        crate::bail!("multiple value-types in the same array {value_type:?}")
+                    }
+                    value_type.into_iter().next().unwrap()
+                };
+                w.write_u32::<LittleEndian>(value_type.to_u32())?;
+                w.write_u32::<LittleEndian>(v.len() as u32)?;
+                for elem in v.iter() {
+                    elem.write(w)?
+                }
+            }
+        }
+        Ok(())
+    }
 }
 
 impl ValueType {
@@ -246,6 +296,21 @@ impl ValueType {
         };
         Ok(v)
     }
+
+    fn to_u32(self) -> u32 {
+        match self {
+            Self::U8 => 0,
+            Self::I8 => 1,
+            Self::U16 => 2,
+            Self::I16 => 3,
+            Self::U32 => 4,
+            Self::I32 => 5,
+            Self::F32 => 6,
+            Self::Bool => 7,
+            Self::String => 8,
+            Self::Array => 9,
+        }
+    }
 }
 
 impl Content {
@@ -312,3 +377,60 @@
         tensor_info.read(reader, self.tensor_data_offset)
     }
 }
+
+fn write_string<W: std::io::Write>(w: &mut W, str: &str) -> Result<()> {
+    let bytes = str.as_bytes();
+    w.write_u32::<LittleEndian>(bytes.len() as u32)?;
+    w.write_all(bytes)?;
+    Ok(())
+}
+
+pub fn write<W: std::io::Seek + std::io::Write>(
+    w: &mut W,
+    metadata: &[(&str, &Value)],
+    tensors: &[(&str, &QTensor)],
+) -> Result<()> {
+    w.write_u32::<LittleEndian>(0x46554747)?;
+    w.write_u32::<LittleEndian>(1)?; // version 1.
+    w.write_u32::<LittleEndian>(tensors.len() as u32)?;
+    w.write_u32::<LittleEndian>(metadata.len() as u32)?; // metadata kv count.
+    for (name, value) in metadata.iter() {
+        write_string(w, name)?;
+        w.write_u32::<LittleEndian>(value.value_type().to_u32())?;
+        value.write(w)?;
+    }
+    let mut offset = 0usize;
+    let mut offsets = Vec::with_capacity(tensors.len());
+    for (name, tensor) in tensors.iter() {
+        write_string(w, name)?;
+        let dims = tensor.shape().dims();
+        w.write_u32::<LittleEndian>(dims.len() as u32)?;
+        for &dim in dims.iter().rev() {
+            w.write_u32::<LittleEndian>(dim as u32)?;
+        }
+        w.write_u32::<LittleEndian>(tensor.dtype().to_u32())?;
+        w.write_u64::<LittleEndian>(offset as u64)?;
+        offsets.push(offset);
+        let size_in_bytes = tensor.storage_size_in_bytes();
+        let padding = 31 - (31 + size_in_bytes) % 32;
+        offset += size_in_bytes + padding;
+    }
+    let pos = w.stream_position()? as usize;
+    let padding = 31 - (31 + pos) % 32;
+    w.write_all(&vec![0u8; padding])?;
+    let tensor_start_pos = w.stream_position()? as usize;
+    for (offset, (_name, tensor)) in offsets.iter().zip(tensors.iter()) {
+        let pos = w.stream_position()? as usize;
+        if tensor_start_pos + offset != pos {
+            crate::bail!(
+                "internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
+            )
+        }
+        let data_ptr = tensor.as_ptr();
+        let size_in_bytes = tensor.storage_size_in_bytes();
+        let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
+        w.write_all(data)?;
+        let padding = 31 - (31 + size_in_bytes) % 32;
+        w.write_all(&vec![0u8; padding])?;
+    }
+    Ok(())
+}
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 568cd9ad..cb788779 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -16,7 +16,7 @@ pub struct QTensor {
     shape: Shape,
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum GgmlDType {
     F32,
     F16,
@@ -56,6 +56,25 @@
         Ok(dtype)
     }
 
+    pub(crate) fn to_u32(self) -> u32 {
+        match self {
+            Self::F32 => 0,
+            Self::F16 => 1,
+            Self::Q4_0 => 2,
+            Self::Q4_1 => 3,
+            Self::Q5_0 => 6,
+            Self::Q5_1 => 7,
+            Self::Q8_0 => 8,
+            Self::Q8_1 => 9,
+            Self::Q2K => 10,
+            Self::Q3K => 11,
+            Self::Q4K => 12,
+            Self::Q5K => 13,
+            Self::Q6K => 14,
+            Self::Q8K => 15,
+        }
+    }
+
     /// The type size for blocks in bytes.
     pub fn type_size(&self) -> usize {
         use k_quants::*;
@@ -99,6 +118,8 @@ pub trait QuantizedType: Send + Sync {
     fn dtype(&self) -> GgmlDType;
     fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
     fn to_float(&self, ys: &mut [f32]) -> Result<()>;
+    fn storage_size_in_bytes(&self) -> usize;
+    fn as_ptr(&self) -> *const u8;
 }
 
 impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
@@ -113,6 +134,14 @@
     fn to_float(&self, ys: &mut [f32]) -> Result<()> {
         T::to_float(self.as_slice(), ys)
     }
+
+    fn storage_size_in_bytes(&self) -> usize {
+        self.len() * std::mem::size_of::<T>()
+    }
+
+    fn as_ptr(&self) -> *const u8 {
+        self.as_ptr() as *const u8
+    }
 }
 
 impl std::fmt::Debug for QTensor {
@@ -186,6 +215,14 @@
     pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
         self.data.matmul_t(mkn, lhs, dst)
     }
+
+    pub fn storage_size_in_bytes(&self) -> usize {
+        self.data.storage_size_in_bytes()
+    }
+
+    pub fn as_ptr(&self) -> *const u8 {
+        self.data.as_ptr()
+    }
 }
 
 #[derive(Debug)]
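
As a quick sanity check, the new `write` entry point can be exercised with a read/re-write round trip driven entirely by this module. The sketch below is illustrative and not part of the diff; it assumes the existing reader API in `gguf_file.rs` (`Content::read`, `Content::tensor`, and the public `metadata` / `tensor_infos` fields) keeps its current shape.

// Round-trip sketch: read an existing GGUF file, then re-serialize it with the
// new `gguf_file::write`. The reader-side names used here (`Content::read`,
// `Content::tensor`, `metadata`, `tensor_infos`) are assumptions about the
// existing code in gguf_file.rs, not part of this diff.
use candle_core::quantized::gguf_file;

fn rewrite_gguf(src: &str, dst: &str) -> candle_core::Result<()> {
    let mut reader = std::fs::File::open(src)?;
    let content = gguf_file::Content::read(&mut reader)?;
    // Pull every tensor back through the existing read path.
    let mut tensors = Vec::new();
    for name in content.tensor_infos.keys() {
        tensors.push((name.clone(), content.tensor(&mut reader, name)?));
    }
    let tensors: Vec<_> = tensors.iter().map(|(n, t)| (n.as_str(), t)).collect();
    let metadata: Vec<_> = content.metadata.iter().map(|(n, v)| (n.as_str(), v)).collect();
    // `write` needs Seek + Write, both of which `File` provides.
    let mut out = std::fs::File::create(dst)?;
    gguf_file::write(&mut out, metadata.as_slice(), tensors.as_slice())?;
    Ok(())
}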