Add a function to write gguf files. (#585)

* Add a function to write gguf files.

* More GGUF file writing.

* Write the tensor data in GGUF files.
This commit is contained in:
Laurent Mazare
2023-08-24 17:03:06 +01:00
committed by GitHub
parent a87c6f7652
commit c265ac50fa
2 changed files with 163 additions and 4 deletions

View File

@ -4,7 +4,7 @@
use super::{GgmlDType, QTensor};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
pub const DEFAULT_ALIGNMENT: u64 = 32;
@ -19,7 +19,7 @@ impl TryFrom<u32> for Magic {
fn try_from(value: u32) -> Result<Self> {
let magic = match value {
0x46554747 | 0x47475546 => Self::Gguf,
_ => crate::bail!("unknown magic {value:08x}"),
_ => crate::bail!("unknown magic 0x{value:08x}"),
};
Ok(magic)
}
@ -86,7 +86,7 @@ fn read_string<R: std::io::Read>(reader: &mut R) -> Result<String> {
Ok(String::from_utf8_lossy(&v).into_owned())
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValueType {
// The value is a 8-bit unsigned integer.
U8,
@ -129,6 +129,21 @@ pub enum Value {
}
impl Value {
pub fn value_type(&self) -> ValueType {
match self {
Self::U8(_) => ValueType::U8,
Self::I8(_) => ValueType::I8,
Self::U16(_) => ValueType::U16,
Self::I16(_) => ValueType::I16,
Self::U32(_) => ValueType::U32,
Self::I32(_) => ValueType::I32,
Self::F32(_) => ValueType::F32,
Self::Bool(_) => ValueType::Bool,
Self::String(_) => ValueType::String,
Self::Array(_) => ValueType::Array,
}
}
pub fn to_u8(&self) -> Result<u8> {
match self {
Self::U8(v) => Ok(*v),
@ -227,6 +242,41 @@ impl Value {
};
Ok(v)
}
fn write<W: std::io::Write>(&self, w: &mut W) -> Result<()> {
match self {
&Self::U8(v) => w.write_u8(v)?,
&Self::I8(v) => w.write_i8(v)?,
&Self::U16(v) => w.write_u16::<LittleEndian>(v)?,
&Self::I16(v) => w.write_i16::<LittleEndian>(v)?,
&Self::U32(v) => w.write_u32::<LittleEndian>(v)?,
&Self::I32(v) => w.write_i32::<LittleEndian>(v)?,
&Self::F32(v) => w.write_f32::<LittleEndian>(v)?,
&Self::Bool(v) => w.write_u8(u8::from(v))?,
Self::String(v) => write_string(w, v.as_str())?,
Self::Array(v) => {
// The `Value` type does not enforce that all the values in an Array have the same
// type.
let value_type = if v.is_empty() {
// Doesn't matter, the array is empty.
ValueType::U32
} else {
let value_type: std::collections::HashSet<_> =
v.iter().map(|elem| elem.value_type()).collect();
if value_type.len() != 1 {
crate::bail!("multiple value-types in the same array {value_type:?}")
}
value_type.into_iter().next().unwrap()
};
w.write_u32::<LittleEndian>(value_type.to_u32())?;
w.write_u32::<LittleEndian>(v.len() as u32)?;
for elem in v.iter() {
elem.write(w)?
}
}
}
Ok(())
}
}
impl ValueType {
@ -246,6 +296,21 @@ impl ValueType {
};
Ok(v)
}
fn to_u32(self) -> u32 {
match self {
Self::U8 => 0,
Self::I8 => 1,
Self::U16 => 2,
Self::I16 => 3,
Self::U32 => 4,
Self::I32 => 5,
Self::F32 => 6,
Self::Bool => 7,
Self::String => 8,
Self::Array => 9,
}
}
}
impl Content {
@ -312,3 +377,60 @@ impl Content {
tensor_info.read(reader, self.tensor_data_offset)
}
}
fn write_string<W: std::io::Write>(w: &mut W, str: &str) -> Result<()> {
let bytes = str.as_bytes();
w.write_u32::<LittleEndian>(bytes.len() as u32)?;
w.write_all(bytes)?;
Ok(())
}
pub fn write<W: std::io::Seek + std::io::Write>(
w: &mut W,
metadata: &[(&str, &Value)],
tensors: &[(&str, &QTensor)],
) -> Result<()> {
w.write_u32::<LittleEndian>(0x46554747)?;
w.write_u32::<LittleEndian>(1)?; // version 1.
w.write_u32::<LittleEndian>(tensors.len() as u32)?;
for (name, value) in metadata.iter() {
write_string(w, name)?;
w.write_u32::<LittleEndian>(value.value_type().to_u32())?;
value.write(w)?;
}
let mut offset = 0usize;
let mut offsets = Vec::with_capacity(tensors.len());
for (name, tensor) in tensors.iter() {
write_string(w, name)?;
let dims = tensor.shape().dims();
w.write_u32::<LittleEndian>(dims.len() as u32)?;
for &dim in dims.iter().rev() {
w.write_u32::<LittleEndian>(dim as u32)?;
}
w.write_u32::<LittleEndian>(tensor.dtype().to_u32())?;
w.write_u64::<LittleEndian>(offset as u64)?;
offsets.push(offset);
let size_in_bytes = tensor.storage_size_in_bytes();
let padding = 31 - (31 + size_in_bytes) % 32;
offset += size_in_bytes + padding;
}
let pos = w.stream_position()? as usize;
let padding = 31 - (31 + pos) % 32;
w.write_all(&vec![0u8; padding])?;
let tensor_start_pos = w.stream_position()? as usize;
for (offset, (_name, tensor)) in offsets.iter().zip(tensors.iter()) {
let pos = w.stream_position()? as usize;
if tensor_start_pos + offset != pos {
crate::bail!(
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
)
}
let data_ptr = tensor.as_ptr();
let size_in_bytes = tensor.storage_size_in_bytes();
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
w.write_all(data)?;
let padding = 31 - (31 + size_in_bytes) % 32;
w.write_all(&vec![0u8; padding])?;
}
Ok(())
}