mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Support some ggml quantized types (#314)
* Add the quantized types for GGML loading. * Support quantization for Q2K. * More quantization support. * Fix some clippy lints.
This commit is contained in:
@ -1,12 +1,207 @@
|
|||||||
//! Support for the GGML file format.
|
//! Support for the GGML file format.
|
||||||
|
|
||||||
use crate::Result;
|
use crate::{DType, Device, Result, Tensor};
|
||||||
use byteorder::{LittleEndian, ReadBytesExt};
|
use byteorder::{LittleEndian, ReadBytesExt};
|
||||||
|
use half::f16;
|
||||||
|
|
||||||
// Default to QK_K 256 rather than 64.
|
// Default to QK_K 256 rather than 64.
|
||||||
pub const QK_K: usize = 256;
|
pub const QK_K: usize = 256;
|
||||||
pub const K_SCALE_SIZE: usize = 12;
|
pub const K_SCALE_SIZE: usize = 12;
|
||||||
|
|
||||||
|
pub const QK4_0: usize = 32;
|
||||||
|
pub const QK4_1: usize = 32;
|
||||||
|
pub const QK5_0: usize = 32;
|
||||||
|
pub const QK5_1: usize = 32;
|
||||||
|
pub const QK8_0: usize = 32;
|
||||||
|
pub const QK8_1: usize = 32;
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
struct BlockQ4_0 {
|
||||||
|
d: f16,
|
||||||
|
qs: [u8; QK4_0 / 2],
|
||||||
|
}
|
||||||
|
// Hacky static_assert
|
||||||
|
const _: [u8; 18] = [0; std::mem::size_of::<BlockQ4_0>()];
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
struct BlockQ4_1 {
|
||||||
|
d: f16,
|
||||||
|
m: f16,
|
||||||
|
qs: [u8; QK4_1 / 2],
|
||||||
|
}
|
||||||
|
const _: [u8; 20] = [0; std::mem::size_of::<BlockQ4_1>()];
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
struct BlockQ5_0 {
|
||||||
|
d: f16,
|
||||||
|
qh: [u8; 4],
|
||||||
|
qs: [u8; QK5_0 / 2],
|
||||||
|
}
|
||||||
|
const _: [u8; 22] = [0; std::mem::size_of::<BlockQ5_0>()];
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
struct BlockQ5_1 {
|
||||||
|
d: f16,
|
||||||
|
m: f16,
|
||||||
|
qh: [u8; 4],
|
||||||
|
qs: [u8; QK5_1 / 2],
|
||||||
|
}
|
||||||
|
const _: [u8; 24] = [0; std::mem::size_of::<BlockQ5_1>()];
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
struct BlockQ8_0 {
|
||||||
|
d: f16,
|
||||||
|
qs: [u8; QK8_0],
|
||||||
|
}
|
||||||
|
const _: [u8; 34] = [0; std::mem::size_of::<BlockQ8_0>()];
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
struct BlockQ8_1 {
|
||||||
|
d: f16,
|
||||||
|
s: f16,
|
||||||
|
qs: [u8; QK8_1],
|
||||||
|
}
|
||||||
|
const _: [u8; 36] = [0; std::mem::size_of::<BlockQ8_1>()];
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
struct BlockQ2K {
|
||||||
|
scales: [u8; QK_K / 16],
|
||||||
|
qs: [u8; QK_K / 4],
|
||||||
|
d: f16,
|
||||||
|
dmin: f16,
|
||||||
|
}
|
||||||
|
const _: [u8; QK_K / 16 + QK_K / 4 + 2 * 2] = [0; std::mem::size_of::<BlockQ2K>()];
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
struct BlockQ3K {
|
||||||
|
hmask: [u8; QK_K / 8],
|
||||||
|
qs: [u8; QK_K / 4],
|
||||||
|
scales: [u8; 12],
|
||||||
|
d: f16,
|
||||||
|
}
|
||||||
|
const _: [u8; QK_K / 8 + QK_K / 4 + 12 + 2] = [0; std::mem::size_of::<BlockQ3K>()];
|
||||||
|
|
||||||
|
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
|
||||||
|
#[repr(C)]
|
||||||
|
struct BlockQ4K {
|
||||||
|
d: f16,
|
||||||
|
dmin: f16,
|
||||||
|
scales: [u8; K_SCALE_SIZE],
|
||||||
|
qs: [u8; QK_K / 2],
|
||||||
|
}
|
||||||
|
const _: [u8; QK_K / 2 + K_SCALE_SIZE + 2 * 2] = [0; std::mem::size_of::<BlockQ4K>()];
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
struct BlockQ5K {
|
||||||
|
d: f16,
|
||||||
|
dmin: f16,
|
||||||
|
scales: [u8; K_SCALE_SIZE],
|
||||||
|
qh: [u8; QK_K / 8],
|
||||||
|
qs: [u8; QK_K / 2],
|
||||||
|
}
|
||||||
|
const _: [u8; QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE] = [0; std::mem::size_of::<BlockQ5K>()];
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
struct BlockQ6K {
|
||||||
|
ql: [u8; QK_K / 2],
|
||||||
|
qh: [u8; QK_K / 4],
|
||||||
|
scales: [i8; QK_K / 16],
|
||||||
|
d: f16,
|
||||||
|
}
|
||||||
|
const _: [u8; 3 * QK_K / 4 + QK_K / 16 + 2] = [0; std::mem::size_of::<BlockQ6K>()];
|
||||||
|
|
||||||
|
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
|
||||||
|
fn dequantize_row_q2k(xs: &[BlockQ2K], ys: &mut [f32]) -> Result<()> {
|
||||||
|
let k = ys.len();
|
||||||
|
if k % QK_K != 0 {
|
||||||
|
crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
|
||||||
|
}
|
||||||
|
let mut ys_index = 0;
|
||||||
|
for x in xs {
|
||||||
|
let d = x.d.to_f32();
|
||||||
|
let min = x.dmin.to_f32();
|
||||||
|
let q = &x.qs;
|
||||||
|
|
||||||
|
let mut is = 0;
|
||||||
|
for n in (0..QK_K).step_by(128) {
|
||||||
|
// Step by 32 over q.
|
||||||
|
let q = &q[n / 4..];
|
||||||
|
let mut shift = 0;
|
||||||
|
for _j in 0..4 {
|
||||||
|
let sc = x.scales[is];
|
||||||
|
is += 1;
|
||||||
|
let dl = d * (sc & 0xF) as f32;
|
||||||
|
let ml = min * (sc >> 4) as f32;
|
||||||
|
for q in &q[..16] {
|
||||||
|
let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
|
||||||
|
ys[ys_index] = y;
|
||||||
|
ys_index += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let sc = x.scales[is];
|
||||||
|
is += 1;
|
||||||
|
let dl = d * (sc & 0xF) as f32;
|
||||||
|
let ml = min * (sc >> 4) as f32;
|
||||||
|
for q in &q[16..32] {
|
||||||
|
let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
|
||||||
|
ys[ys_index] = y;
|
||||||
|
ys_index += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
shift += 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
|
||||||
|
if j < 4 {
|
||||||
|
let d = q[j] & 63;
|
||||||
|
let m = q[j + 4] & 63;
|
||||||
|
(d, m)
|
||||||
|
} else {
|
||||||
|
let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
|
||||||
|
let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
|
||||||
|
(d, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
|
||||||
|
fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
|
||||||
|
let k = ys.len();
|
||||||
|
if k % QK_K != 0 {
|
||||||
|
crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
|
||||||
|
}
|
||||||
|
let mut ys_index = 0;
|
||||||
|
for x in xs.iter() {
|
||||||
|
let d = x.d.to_f32();
|
||||||
|
let min = x.dmin.to_f32();
|
||||||
|
let q = &x.qs;
|
||||||
|
let mut is = 0;
|
||||||
|
for j in (0..QK_K).step_by(64) {
|
||||||
|
let q = &q[j / 2..j / 2 + 32];
|
||||||
|
let (sc, m) = get_scale_min_k4(is, &x.scales);
|
||||||
|
let d1 = d * sc as f32;
|
||||||
|
let m1 = min * m as f32;
|
||||||
|
let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
|
||||||
|
let d2 = d * sc as f32;
|
||||||
|
let m2 = min * m as f32;
|
||||||
|
for q in q {
|
||||||
|
let y = d1 * (q & 0xF) as f32 - m1;
|
||||||
|
ys[ys_index] = y;
|
||||||
|
ys_index += 1;
|
||||||
|
}
|
||||||
|
for q in q {
|
||||||
|
let y = d2 * (q >> 4) as f32 - m2;
|
||||||
|
ys[ys_index] = y;
|
||||||
|
ys_index += 1;
|
||||||
|
}
|
||||||
|
is += 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
|
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
enum Magic {
|
enum Magic {
|
||||||
@ -161,19 +356,18 @@ impl GgmlDType {
|
|||||||
match self {
|
match self {
|
||||||
Self::F32 => 4,
|
Self::F32 => 4,
|
||||||
Self::F16 => 2,
|
Self::F16 => 2,
|
||||||
Self::Q4_0 => 18,
|
Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
|
||||||
Self::Q4_1 => 20,
|
Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
|
||||||
Self::Q5_0 => 22,
|
Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
|
||||||
Self::Q5_1 => 24,
|
Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
|
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
|
||||||
Self::Q8_0 => 34,
|
Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
|
||||||
Self::Q8_1 => 36,
|
Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
|
||||||
Self::Q2K => QK_K / 16 + QK_K / 4 + 2 * 2,
|
Self::Q2K => std::mem::size_of::<BlockQ2K>(),
|
||||||
Self::Q3K => QK_K / 8 + QK_K / 4 + 12 + 2,
|
Self::Q3K => std::mem::size_of::<BlockQ3K>(),
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
|
Self::Q4K => std::mem::size_of::<BlockQ4K>(),
|
||||||
Self::Q4K => QK_K / 2 + K_SCALE_SIZE + 2 * 2,
|
Self::Q5K => std::mem::size_of::<BlockQ5K>(),
|
||||||
Self::Q5K => QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE,
|
Self::Q6K => std::mem::size_of::<BlockQ6K>(),
|
||||||
Self::Q6K => 3 * QK_K / 4 + QK_K / 16 + 2,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -181,12 +375,12 @@ impl GgmlDType {
|
|||||||
match self {
|
match self {
|
||||||
Self::F32 => 1,
|
Self::F32 => 1,
|
||||||
Self::F16 => 1,
|
Self::F16 => 1,
|
||||||
Self::Q4_0 => 32,
|
Self::Q4_0 => QK4_0,
|
||||||
Self::Q4_1 => 32,
|
Self::Q4_1 => QK4_1,
|
||||||
Self::Q5_0 => 32,
|
Self::Q5_0 => QK5_0,
|
||||||
Self::Q5_1 => 32,
|
Self::Q5_1 => QK5_1,
|
||||||
Self::Q8_0 => 32,
|
Self::Q8_0 => QK8_0,
|
||||||
Self::Q8_1 => 32,
|
Self::Q8_1 => QK8_1,
|
||||||
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K => QK_K,
|
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K => QK_K,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -197,41 +391,84 @@ pub struct Content {
|
|||||||
pub magic: VersionedMagic,
|
pub magic: VersionedMagic,
|
||||||
pub hparams: HParams,
|
pub hparams: HParams,
|
||||||
pub vocab: Vocab,
|
pub vocab: Vocab,
|
||||||
|
pub tensors: Vec<(String, Tensor)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||||
|
reader: &mut R,
|
||||||
|
magic: VersionedMagic,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<(String, Tensor)> {
|
||||||
|
let n_dims = reader.read_u32::<LittleEndian>()?;
|
||||||
|
let name_len = reader.read_u32::<LittleEndian>()?;
|
||||||
|
let dtype = reader.read_u32::<LittleEndian>()?;
|
||||||
|
let dtype = GgmlDType::from_u32(dtype)?;
|
||||||
|
let mut dims = vec![0u32; n_dims as usize];
|
||||||
|
reader.read_u32_into::<LittleEndian>(&mut dims)?;
|
||||||
|
let mut name = vec![0u8; name_len as usize];
|
||||||
|
reader.read_exact(&mut name)?;
|
||||||
|
let name = String::from_utf8_lossy(&name).into_owned();
|
||||||
|
|
||||||
|
if magic.align32() {
|
||||||
|
let pos = reader.stream_position()?;
|
||||||
|
reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
|
||||||
|
}
|
||||||
|
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
|
||||||
|
let tensor_elems = dims.iter().product::<usize>();
|
||||||
|
let size_in_bytes = tensor_elems * dtype.type_size() / dtype.blck_size();
|
||||||
|
println!("{name} {dtype:?} {dims:?}");
|
||||||
|
// TODO: Mmap version to avoid copying the data around?
|
||||||
|
let mut raw_data = vec![0u8; size_in_bytes];
|
||||||
|
reader.read_exact(&mut raw_data)?;
|
||||||
|
let tensor = match dtype {
|
||||||
|
GgmlDType::F32 => Tensor::from_raw_buffer(&raw_data, DType::F32, &dims, device)?,
|
||||||
|
GgmlDType::F16 => Tensor::from_raw_buffer(&raw_data, DType::F16, &dims, device)?,
|
||||||
|
GgmlDType::Q2K => {
|
||||||
|
let mut f32_data = vec![0f32; tensor_elems];
|
||||||
|
let raw_data_ptr = raw_data.as_ptr();
|
||||||
|
let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ2K>();
|
||||||
|
let raw_data =
|
||||||
|
unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ2K, n_blocks) };
|
||||||
|
dequantize_row_q2k(raw_data, &mut f32_data)?;
|
||||||
|
// Maybe we should use bf16 instead?
|
||||||
|
Tensor::from_vec(f32_data, dims, device)?
|
||||||
|
}
|
||||||
|
GgmlDType::Q4K => {
|
||||||
|
let mut f32_data = vec![0f32; tensor_elems];
|
||||||
|
let raw_data_ptr = raw_data.as_ptr();
|
||||||
|
let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ4K>();
|
||||||
|
let raw_data =
|
||||||
|
unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ4K, n_blocks) };
|
||||||
|
dequantize_row_q4k(raw_data, &mut f32_data)?;
|
||||||
|
Tensor::from_vec(f32_data, dims, device)?
|
||||||
|
}
|
||||||
|
_ => crate::bail!("quantized type {dtype:?} used in {name} is not supported yet"),
|
||||||
|
};
|
||||||
|
Ok((name, tensor))
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Content {
|
impl Content {
|
||||||
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
|
pub fn read<R: std::io::Seek + std::io::Read>(
|
||||||
|
reader: &mut R,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Content> {
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
|
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
|
||||||
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
|
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
|
||||||
reader.seek(std::io::SeekFrom::Start(0))?;
|
reader.seek(std::io::SeekFrom::Start(0))?;
|
||||||
let magic = VersionedMagic::read(reader)?;
|
let magic = VersionedMagic::read(reader)?;
|
||||||
let hparams = HParams::read(reader)?;
|
let hparams = HParams::read(reader)?;
|
||||||
let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
|
let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
|
||||||
|
let mut tensors = vec![];
|
||||||
|
|
||||||
while reader.stream_position()? != last_position {
|
while reader.stream_position()? != last_position {
|
||||||
let n_dims = reader.read_u32::<LittleEndian>()?;
|
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
||||||
let name_len = reader.read_u32::<LittleEndian>()?;
|
tensors.push((name, tensor))
|
||||||
let dtype = reader.read_u32::<LittleEndian>()?;
|
|
||||||
let dtype = GgmlDType::from_u32(dtype)?;
|
|
||||||
let mut dims = vec![0u32; n_dims as usize];
|
|
||||||
reader.read_u32_into::<LittleEndian>(&mut dims)?;
|
|
||||||
let mut name = vec![0u8; name_len as usize];
|
|
||||||
reader.read_exact(&mut name)?;
|
|
||||||
let name = String::from_utf8_lossy(&name).into_owned();
|
|
||||||
|
|
||||||
if magic.align32() {
|
|
||||||
let pos = reader.stream_position()?;
|
|
||||||
reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
|
|
||||||
}
|
|
||||||
let tensor_elems = dims.iter().map(|&u| u as usize).product::<usize>();
|
|
||||||
let tensor_size = tensor_elems * dtype.type_size() / dtype.blck_size();
|
|
||||||
println!("{name} {dtype:?} {dims:?}");
|
|
||||||
reader.seek(std::io::SeekFrom::Current(tensor_size as i64))?;
|
|
||||||
}
|
}
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
magic,
|
magic,
|
||||||
hparams,
|
hparams,
|
||||||
vocab,
|
vocab,
|
||||||
|
tensors,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user