Add quantized tensors. (#458)

* Add quantized tensors.

* Implement the debug trait for QTensor.

* Add the QMatMul custom op.
Author: Laurent Mazare
Date: 2023-08-15 22:45:53 +01:00
Committed by: GitHub
Parent: b8263aa15c
Commit: ca449f9ee1
3 changed files with 140 additions and 108 deletions
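
With this change, GGML files load into `QTensor`s that keep their quantized blocks instead of being eagerly dequantized to f32 `Tensor`s. A minimal sketch of the new entry point, assuming the module path `candle::quantized::ggml_file` and an on-disk model at `./model.bin` (both illustrative, not part of the commit):

```rust
use candle::quantized::ggml_file::Content; // module path assumed from the file under diff

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut reader = std::io::BufReader::new(std::fs::File::open("./model.bin")?);
    // `Content::read` no longer takes a `DType` or `Device`:
    // tensors come back quantized, as `(name, QTensor)` pairs.
    let content = Content::read(&mut reader)?;
    for (name, _qtensor) in content.tensors.iter() {
        println!("{name}");
    }
    Ok(())
}
```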


@@ -1,7 +1,7 @@
 //! Support for the GGML file format.
 use super::{k_quants, GgmlDType};
-use crate::{DType, Device, Result, Tensor};
+use crate::Result;
 use byteorder::{LittleEndian, ReadBytesExt};
 // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
@@ -116,121 +116,47 @@ impl Vocab {
     }
 }
-fn dequantize_and_create_tensor<T: super::GgmlType>(
+fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     raw_data: &[u8],
-    tensor_elems: usize,
     size_in_bytes: usize,
     dims: Vec<usize>,
-    device: &Device,
-) -> Result<Tensor> {
-    let mut f32_data = vec![0f32; tensor_elems];
+) -> Result<super::QTensor> {
     let raw_data_ptr = raw_data.as_ptr();
     let n_blocks = size_in_bytes / std::mem::size_of::<T>();
-    let raw_data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
-    T::to_float(raw_data, &mut f32_data)?;
-    Tensor::from_vec(f32_data, dims, device)
+    let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
+    Ok(super::QTensor::new(data.to_vec(), dims))
 }
 /// Creates a [Tensor] from a raw GGML tensor.
-pub fn tensor_from_ggml(
+pub fn qtensor_from_ggml(
     ggml_dtype: GgmlDType,
     raw_data: &[u8],
     dims: Vec<usize>,
-    dtype: DType,
-    device: &Device,
-) -> Result<Tensor> {
+) -> Result<super::QTensor> {
     let tensor_elems = dims.iter().product::<usize>();
     let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
-    let tensor = match ggml_dtype {
-        GgmlDType::F32 => Tensor::from_raw_buffer(raw_data, DType::F32, &dims, device),
-        GgmlDType::F16 => Tensor::from_raw_buffer(raw_data, DType::F16, &dims, device),
-        GgmlDType::Q4_0 => dequantize_and_create_tensor::<k_quants::BlockQ4_0>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q4_1 => dequantize_and_create_tensor::<k_quants::BlockQ4_1>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q5_0 => dequantize_and_create_tensor::<k_quants::BlockQ5_0>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q5_1 => dequantize_and_create_tensor::<k_quants::BlockQ5_1>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q8_0 => dequantize_and_create_tensor::<k_quants::BlockQ8_0>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q2K => dequantize_and_create_tensor::<k_quants::BlockQ2K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q3K => dequantize_and_create_tensor::<k_quants::BlockQ3K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q4K => dequantize_and_create_tensor::<k_quants::BlockQ4K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q5K => dequantize_and_create_tensor::<k_quants::BlockQ5K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q6K => dequantize_and_create_tensor::<k_quants::BlockQ6K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        _ => crate::bail!("quantized type {dtype:?} is not supported yet"),
-    }?;
-    //We only have ggml-quant to f32 conversions, meaning we have to convert to the desired type
-    if tensor.dtype() != dtype {
-        tensor.to_dtype(dtype)
-    } else {
-        Ok(tensor)
+    match ggml_dtype {
+        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
+        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
+        _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
     }
 }
 fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     reader: &mut R,
     magic: VersionedMagic,
-    dtype: DType,
-    device: &Device,
-) -> Result<(String, Tensor)> {
+) -> Result<(String, super::QTensor)> {
     let n_dims = reader.read_u32::<LittleEndian>()?;
     let name_len = reader.read_u32::<LittleEndian>()?;
     let ggml_dtype = reader.read_u32::<LittleEndian>()?;
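
Worth noting between the hunks: the new `from_raw_data` above reinterprets the raw byte buffer as a slice of quantized blocks rather than expanding it to f32. A standalone sketch of that reinterpretation pattern, with a hypothetical 4-byte `Block` type standing in for the `k_quants` block structs:

```rust
/// Hypothetical stand-in for a quantized block type such as k_quants::BlockQ4_0.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct Block([u8; 4]);

fn blocks_from_bytes(raw_data: &[u8]) -> Vec<Block> {
    // Same arithmetic as `from_raw_data`, where the commit computes
    // `size_in_bytes / std::mem::size_of::<T>()`; here the byte length
    // of the buffer plays the role of `size_in_bytes`.
    let n_blocks = raw_data.len() / std::mem::size_of::<Block>();
    // Safety: `Block` is `repr(C)`, its alignment is 1 (it only holds bytes),
    // and the buffer contains at least `n_blocks` complete blocks.
    let blocks =
        unsafe { std::slice::from_raw_parts(raw_data.as_ptr() as *const Block, n_blocks) };
    // Like the commit's `data.to_vec()`, this copies the blocks out of the
    // borrowed buffer so the result owns its storage.
    blocks.to_vec()
}

fn main() {
    let raw = [1u8, 2, 3, 4, 5, 6, 7, 8];
    let blocks = blocks_from_bytes(&raw);
    assert_eq!(blocks.len(), 2);
    println!("{blocks:?}");
}
```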
@@ -252,26 +178,21 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     // TODO: Mmap version to avoid copying the data around?
     let mut raw_data = vec![0u8; size_in_bytes];
     reader.read_exact(&mut raw_data)?;
-    match tensor_from_ggml(ggml_dtype, &raw_data, dims, dtype, device) {
+    match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
         Ok(tensor) => Ok((name, tensor)),
         Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
     }
 }
 #[derive(Debug)]
 pub struct Content {
     pub magic: VersionedMagic,
     pub hparams: HParams,
     pub vocab: Vocab,
-    pub tensors: Vec<(String, Tensor)>,
+    pub tensors: Vec<(String, super::QTensor)>,
 }
 impl Content {
-    pub fn read<R: std::io::Seek + std::io::Read>(
-        reader: &mut R,
-        dtype: DType,
-        device: &Device,
-    ) -> Result<Content> {
+    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
         // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
         let last_position = reader.seek(std::io::SeekFrom::End(0))?;
         reader.seek(std::io::SeekFrom::Start(0))?;
@@ -281,7 +202,7 @@ impl Content {
         let mut tensors = vec![];
         while reader.stream_position()? != last_position {
-            let (name, tensor) = read_one_tensor(reader, magic, dtype, device)?;
+            let (name, tensor) = read_one_tensor(reader, magic)?;
             tensors.push((name, tensor))
         }
         Ok(Self {