mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a function to write gguf files. (#585)
* Add a function to write gguf files. * More GGUF file writing. * Write the tensor data in GGUF files.
This commit is contained in:
@ -4,7 +4,7 @@
|
||||
|
||||
use super::{GgmlDType, QTensor};
|
||||
use crate::Result;
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub const DEFAULT_ALIGNMENT: u64 = 32;
|
||||
@ -19,7 +19,7 @@ impl TryFrom<u32> for Magic {
|
||||
fn try_from(value: u32) -> Result<Self> {
|
||||
let magic = match value {
|
||||
0x46554747 | 0x47475546 => Self::Gguf,
|
||||
_ => crate::bail!("unknown magic {value:08x}"),
|
||||
_ => crate::bail!("unknown magic 0x{value:08x}"),
|
||||
};
|
||||
Ok(magic)
|
||||
}
|
||||
@ -86,7 +86,7 @@ fn read_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
|
||||
Ok(String::from_utf8_lossy(&v).into_owned())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum ValueType {
|
||||
// The value is a 8-bit unsigned integer.
|
||||
U8,
|
||||
@ -129,6 +129,21 @@ pub enum Value {
|
||||
}
|
||||
|
||||
impl Value {
|
||||
pub fn value_type(&self) -> ValueType {
|
||||
match self {
|
||||
Self::U8(_) => ValueType::U8,
|
||||
Self::I8(_) => ValueType::I8,
|
||||
Self::U16(_) => ValueType::U16,
|
||||
Self::I16(_) => ValueType::I16,
|
||||
Self::U32(_) => ValueType::U32,
|
||||
Self::I32(_) => ValueType::I32,
|
||||
Self::F32(_) => ValueType::F32,
|
||||
Self::Bool(_) => ValueType::Bool,
|
||||
Self::String(_) => ValueType::String,
|
||||
Self::Array(_) => ValueType::Array,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_u8(&self) -> Result<u8> {
|
||||
match self {
|
||||
Self::U8(v) => Ok(*v),
|
||||
@ -227,6 +242,41 @@ impl Value {
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
fn write<W: std::io::Write>(&self, w: &mut W) -> Result<()> {
|
||||
match self {
|
||||
&Self::U8(v) => w.write_u8(v)?,
|
||||
&Self::I8(v) => w.write_i8(v)?,
|
||||
&Self::U16(v) => w.write_u16::<LittleEndian>(v)?,
|
||||
&Self::I16(v) => w.write_i16::<LittleEndian>(v)?,
|
||||
&Self::U32(v) => w.write_u32::<LittleEndian>(v)?,
|
||||
&Self::I32(v) => w.write_i32::<LittleEndian>(v)?,
|
||||
&Self::F32(v) => w.write_f32::<LittleEndian>(v)?,
|
||||
&Self::Bool(v) => w.write_u8(u8::from(v))?,
|
||||
Self::String(v) => write_string(w, v.as_str())?,
|
||||
Self::Array(v) => {
|
||||
// The `Value` type does not enforce that all the values in an Array have the same
|
||||
// type.
|
||||
let value_type = if v.is_empty() {
|
||||
// Doesn't matter, the array is empty.
|
||||
ValueType::U32
|
||||
} else {
|
||||
let value_type: std::collections::HashSet<_> =
|
||||
v.iter().map(|elem| elem.value_type()).collect();
|
||||
if value_type.len() != 1 {
|
||||
crate::bail!("multiple value-types in the same array {value_type:?}")
|
||||
}
|
||||
value_type.into_iter().next().unwrap()
|
||||
};
|
||||
w.write_u32::<LittleEndian>(value_type.to_u32())?;
|
||||
w.write_u32::<LittleEndian>(v.len() as u32)?;
|
||||
for elem in v.iter() {
|
||||
elem.write(w)?
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueType {
|
||||
@ -246,6 +296,21 @@ impl ValueType {
|
||||
};
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
fn to_u32(self) -> u32 {
|
||||
match self {
|
||||
Self::U8 => 0,
|
||||
Self::I8 => 1,
|
||||
Self::U16 => 2,
|
||||
Self::I16 => 3,
|
||||
Self::U32 => 4,
|
||||
Self::I32 => 5,
|
||||
Self::F32 => 6,
|
||||
Self::Bool => 7,
|
||||
Self::String => 8,
|
||||
Self::Array => 9,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Content {
|
||||
@ -312,3 +377,60 @@ impl Content {
|
||||
tensor_info.read(reader, self.tensor_data_offset)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_string<W: std::io::Write>(w: &mut W, str: &str) -> Result<()> {
|
||||
let bytes = str.as_bytes();
|
||||
w.write_u32::<LittleEndian>(bytes.len() as u32)?;
|
||||
w.write_all(bytes)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write<W: std::io::Seek + std::io::Write>(
|
||||
w: &mut W,
|
||||
metadata: &[(&str, &Value)],
|
||||
tensors: &[(&str, &QTensor)],
|
||||
) -> Result<()> {
|
||||
w.write_u32::<LittleEndian>(0x46554747)?;
|
||||
w.write_u32::<LittleEndian>(1)?; // version 1.
|
||||
w.write_u32::<LittleEndian>(tensors.len() as u32)?;
|
||||
for (name, value) in metadata.iter() {
|
||||
write_string(w, name)?;
|
||||
w.write_u32::<LittleEndian>(value.value_type().to_u32())?;
|
||||
value.write(w)?;
|
||||
}
|
||||
let mut offset = 0usize;
|
||||
let mut offsets = Vec::with_capacity(tensors.len());
|
||||
for (name, tensor) in tensors.iter() {
|
||||
write_string(w, name)?;
|
||||
let dims = tensor.shape().dims();
|
||||
w.write_u32::<LittleEndian>(dims.len() as u32)?;
|
||||
for &dim in dims.iter().rev() {
|
||||
w.write_u32::<LittleEndian>(dim as u32)?;
|
||||
}
|
||||
w.write_u32::<LittleEndian>(tensor.dtype().to_u32())?;
|
||||
w.write_u64::<LittleEndian>(offset as u64)?;
|
||||
offsets.push(offset);
|
||||
let size_in_bytes = tensor.storage_size_in_bytes();
|
||||
let padding = 31 - (31 + size_in_bytes) % 32;
|
||||
offset += size_in_bytes + padding;
|
||||
}
|
||||
let pos = w.stream_position()? as usize;
|
||||
let padding = 31 - (31 + pos) % 32;
|
||||
w.write_all(&vec![0u8; padding])?;
|
||||
let tensor_start_pos = w.stream_position()? as usize;
|
||||
for (offset, (_name, tensor)) in offsets.iter().zip(tensors.iter()) {
|
||||
let pos = w.stream_position()? as usize;
|
||||
if tensor_start_pos + offset != pos {
|
||||
crate::bail!(
|
||||
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
|
||||
)
|
||||
}
|
||||
let data_ptr = tensor.as_ptr();
|
||||
let size_in_bytes = tensor.storage_size_in_bytes();
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
w.write_all(data)?;
|
||||
let padding = 31 - (31 + size_in_bytes) % 32;
|
||||
w.write_all(&vec![0u8; padding])?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ pub struct QTensor {
|
||||
shape: Shape,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum GgmlDType {
|
||||
F32,
|
||||
F16,
|
||||
@ -56,6 +56,25 @@ impl GgmlDType {
|
||||
Ok(dtype)
|
||||
}
|
||||
|
||||
pub(crate) fn to_u32(self) -> u32 {
|
||||
match self {
|
||||
Self::F32 => 0,
|
||||
Self::F16 => 1,
|
||||
Self::Q4_0 => 2,
|
||||
Self::Q4_1 => 3,
|
||||
Self::Q5_0 => 6,
|
||||
Self::Q5_1 => 7,
|
||||
Self::Q8_0 => 8,
|
||||
Self::Q8_1 => 9,
|
||||
Self::Q2K => 10,
|
||||
Self::Q3K => 11,
|
||||
Self::Q4K => 12,
|
||||
Self::Q5K => 13,
|
||||
Self::Q6K => 14,
|
||||
Self::Q8K => 15,
|
||||
}
|
||||
}
|
||||
|
||||
/// The type size for blocks in bytes.
|
||||
pub fn type_size(&self) -> usize {
|
||||
use k_quants::*;
|
||||
@ -99,6 +118,8 @@ pub trait QuantizedType: Send + Sync {
|
||||
fn dtype(&self) -> GgmlDType;
|
||||
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
|
||||
fn storage_size_in_bytes(&self) -> usize;
|
||||
fn as_ptr(&self) -> *const u8;
|
||||
}
|
||||
|
||||
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
@ -113,6 +134,14 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
|
||||
T::to_float(self.as_slice(), ys)
|
||||
}
|
||||
|
||||
fn storage_size_in_bytes(&self) -> usize {
|
||||
self.len() * std::mem::size_of::<T>()
|
||||
}
|
||||
|
||||
fn as_ptr(&self) -> *const u8 {
|
||||
self.as_ptr() as *const u8
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for QTensor {
|
||||
@ -186,6 +215,14 @@ impl QTensor {
|
||||
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
|
||||
self.data.matmul_t(mkn, lhs, dst)
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
self.data.storage_size_in_bytes()
|
||||
}
|
||||
|
||||
pub fn as_ptr(&self) -> *const u8 {
|
||||
self.data.as_ptr()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
Reference in New Issue
Block a user