mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Basic quantization support (#453)
* Add a vecdot trait. * Start implementing mul_mat. * Add to the mul mat implementation. * Add q8_0 quantization. * Implement the GgmlType trait for all types. * Add the missing block. * Add a TODO.
This commit is contained in:
@ -15,15 +15,27 @@ pub const QK5_1: usize = 32;
|
||||
pub const QK8_0: usize = 32;
|
||||
pub const QK8_1: usize = 32;
|
||||
|
||||
pub trait GgmlType: Sized {
|
||||
const DTYPE: GgmlDType;
|
||||
const BLCK_SIZE: usize;
|
||||
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>;
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>;
|
||||
|
||||
type VecDotType: GgmlType;
|
||||
// Dot product used as a building block for quantized mat-mul.
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
struct BlockQ4_0 {
|
||||
pub struct BlockQ4_0 {
|
||||
d: f16,
|
||||
qs: [u8; QK4_0 / 2],
|
||||
}
|
||||
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
|
||||
|
||||
#[repr(C)]
|
||||
struct BlockQ4_1 {
|
||||
pub struct BlockQ4_1 {
|
||||
d: f16,
|
||||
m: f16,
|
||||
qs: [u8; QK4_1 / 2],
|
||||
@ -31,7 +43,7 @@ struct BlockQ4_1 {
|
||||
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
|
||||
|
||||
#[repr(C)]
|
||||
struct BlockQ5_0 {
|
||||
pub struct BlockQ5_0 {
|
||||
d: f16,
|
||||
qh: [u8; 4],
|
||||
qs: [u8; QK5_0 / 2],
|
||||
@ -39,7 +51,7 @@ struct BlockQ5_0 {
|
||||
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
|
||||
|
||||
#[repr(C)]
|
||||
struct BlockQ5_1 {
|
||||
pub struct BlockQ5_1 {
|
||||
d: f16,
|
||||
m: f16,
|
||||
qh: [u8; 4],
|
||||
@ -48,14 +60,14 @@ struct BlockQ5_1 {
|
||||
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
|
||||
|
||||
#[repr(C)]
|
||||
struct BlockQ8_0 {
|
||||
pub struct BlockQ8_0 {
|
||||
d: f16,
|
||||
qs: [u8; QK8_0],
|
||||
}
|
||||
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
|
||||
|
||||
#[repr(C)]
|
||||
struct BlockQ8_1 {
|
||||
pub struct BlockQ8_1 {
|
||||
d: f16,
|
||||
s: f16,
|
||||
qs: [u8; QK8_1],
|
||||
@ -63,7 +75,7 @@ struct BlockQ8_1 {
|
||||
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
|
||||
|
||||
#[repr(C)]
|
||||
struct BlockQ2K {
|
||||
pub struct BlockQ2K {
|
||||
scales: [u8; QK_K / 16],
|
||||
qs: [u8; QK_K / 4],
|
||||
d: f16,
|
||||
@ -72,7 +84,7 @@ struct BlockQ2K {
|
||||
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
|
||||
|
||||
#[repr(C)]
|
||||
struct BlockQ3K {
|
||||
pub struct BlockQ3K {
|
||||
hmask: [u8; QK_K / 8],
|
||||
qs: [u8; QK_K / 4],
|
||||
scales: [u8; 12],
|
||||
@ -82,7 +94,7 @@ const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
|
||||
#[repr(C)]
|
||||
struct BlockQ4K {
|
||||
pub struct BlockQ4K {
|
||||
d: f16,
|
||||
dmin: f16,
|
||||
scales: [u8; K_SCALE_SIZE],
|
||||
@ -91,7 +103,7 @@ struct BlockQ4K {
|
||||
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
|
||||
|
||||
#[repr(C)]
|
||||
struct BlockQ5K {
|
||||
pub struct BlockQ5K {
|
||||
d: f16,
|
||||
dmin: f16,
|
||||
scales: [u8; K_SCALE_SIZE],
|
||||
@ -102,7 +114,7 @@ const _: () =
|
||||
assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
|
||||
|
||||
#[repr(C)]
|
||||
struct BlockQ6K {
|
||||
pub struct BlockQ6K {
|
||||
ql: [u8; QK_K / 2],
|
||||
qh: [u8; QK_K / 4],
|
||||
scales: [i8; QK_K / 16],
|
||||
@ -110,30 +122,29 @@ struct BlockQ6K {
|
||||
}
|
||||
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525
|
||||
fn dequantize_row_q4_0(xs: &[BlockQ4_0], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK4_0 != 0 {
|
||||
crate::bail!("dequantize_row_q4_0: {k} is not divisible by {QK4_0}")
|
||||
}
|
||||
|
||||
let nb = k / QK4_0;
|
||||
for i in 0..nb {
|
||||
let d = xs[i].d.to_f32();
|
||||
|
||||
for j in 0..(QK4_0 / 2) {
|
||||
let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8;
|
||||
let x1 = (xs[i].qs[j] >> 4) as i16 - 8;
|
||||
|
||||
ys[i * QK4_0 + j] = (x0 as f32) * d;
|
||||
ys[i * QK4_0 + j + QK4_0 / 2] = (x1 as f32) * d;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
#[repr(C)]
|
||||
pub struct BlockQ8K {
|
||||
d: f32,
|
||||
qs: [i8; QK_K],
|
||||
bsums: [i16; QK_K / 16],
|
||||
}
|
||||
const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::<BlockQ8K>());
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1545
|
||||
fn dequantize_row_q4_1(xs: &[BlockQ4_1], ys: &mut [f32]) -> Result<()> {
|
||||
impl GgmlType for BlockQ4_1 {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q4_1;
|
||||
const BLCK_SIZE: usize = QK4_1;
|
||||
type VecDotType = BlockQ8_1;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1545
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK4_1 != 0 {
|
||||
crate::bail!("dequantize_row_q4_1: {k} is not divisible by {QK4_1}");
|
||||
@ -153,10 +164,24 @@ fn dequantize_row_q4_1(xs: &[BlockQ4_1], ys: &mut [f32]) -> Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1566
|
||||
fn dequantize_row_q5_0(xs: &[BlockQ5_0], ys: &mut [f32]) -> Result<()> {
|
||||
impl GgmlType for BlockQ5_0 {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q5_0;
|
||||
const BLCK_SIZE: usize = QK5_0;
|
||||
type VecDotType = BlockQ8_0;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1566
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK5_0 != 0 {
|
||||
crate::bail!("dequantize_row_q5_0: {k} is not divisible by {QK5_0}");
|
||||
@ -179,10 +204,24 @@ fn dequantize_row_q5_0(xs: &[BlockQ5_0], ys: &mut [f32]) -> Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1592
|
||||
fn dequantize_row_q5_1(xs: &[BlockQ5_1], ys: &mut [f32]) -> Result<()> {
|
||||
impl GgmlType for BlockQ5_1 {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q5_1;
|
||||
const BLCK_SIZE: usize = QK5_1;
|
||||
type VecDotType = BlockQ8_1;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1592
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK5_1 != 0 {
|
||||
crate::bail!("dequantize_row_q5_1: {k} is not divisible by {QK5_1}");
|
||||
@ -206,30 +245,23 @@ fn dequantize_row_q5_1(xs: &[BlockQ5_1], ys: &mut [f32]) -> Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619
|
||||
fn dequantize_row_q8_0(vx: &[BlockQ8_0], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK8_0 != 0 {
|
||||
crate::bail!("dequantize_row_q8_0: {k} is not divisible by {QK8_0}");
|
||||
impl GgmlType for BlockQ2K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q2K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
let nb = k / QK8_0;
|
||||
let xs: &[BlockQ8_0] = unsafe { std::mem::transmute(vx) };
|
||||
|
||||
for i in 0..nb {
|
||||
let d = xs[i].d.to_f32();
|
||||
|
||||
for j in 0..QK8_0 {
|
||||
ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
|
||||
fn dequantize_row_q2k(xs: &[BlockQ2K], ys: &mut [f32]) -> Result<()> {
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK_K != 0 {
|
||||
crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
|
||||
@ -271,6 +303,7 @@ fn dequantize_row_q2k(xs: &[BlockQ2K], ys: &mut [f32]) -> Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
|
||||
@ -284,8 +317,21 @@ fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
|
||||
(d, m)
|
||||
}
|
||||
}
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
|
||||
fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
|
||||
|
||||
impl GgmlType for BlockQ4K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q4K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK_K != 0 {
|
||||
crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}")
|
||||
@ -318,15 +364,42 @@ fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
|
||||
fn dequantize_row_q3k(_xs: &[BlockQ3K], _ys: &mut [f32]) -> Result<()> {
|
||||
impl GgmlType for BlockQ3K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q3K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
|
||||
fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928
|
||||
fn dequantize_row_q5k(xs: &[BlockQ5K], ys: &mut [f32]) -> Result<()> {
|
||||
impl GgmlType for BlockQ5K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q5K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK_K != 0 {
|
||||
crate::bail!("dequantize_row_q5k: {k} is not divisible by {QK_K}")
|
||||
@ -366,10 +439,24 @@ fn dequantize_row_q5k(xs: &[BlockQ5K], ys: &mut [f32]) -> Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
|
||||
fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> {
|
||||
impl GgmlType for BlockQ6K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q6K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK_K != 0 {
|
||||
crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
|
||||
@ -399,6 +486,234 @@ fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> {
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ8K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q8K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
|
||||
fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ4_0 {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q4_0;
|
||||
const BLCK_SIZE: usize = QK4_0;
|
||||
type VecDotType = BlockQ8_0;
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK4_0 != 0 {
|
||||
crate::bail!("dequantize_row_q4_0: {k} is not divisible by {QK4_0}")
|
||||
}
|
||||
|
||||
let nb = k / QK4_0;
|
||||
for i in 0..nb {
|
||||
let d = xs[i].d.to_f32();
|
||||
|
||||
for j in 0..(QK4_0 / 2) {
|
||||
let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8;
|
||||
let x1 = (xs[i].qs[j] >> 4) as i16 - 8;
|
||||
|
||||
ys[i * QK4_0 + j] = (x0 as f32) * d;
|
||||
ys[i * QK4_0 + j + QK4_0 / 2] = (x1 as f32) * d;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn from_float(_: &[f32], _: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
let nb = n / qk;
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
|
||||
}
|
||||
|
||||
// Generic implementation.
|
||||
let mut sumf = 0f32;
|
||||
for i in 0..nb {
|
||||
let mut sum_i = 0;
|
||||
for j in 0..qk / 2 {
|
||||
let v0 = (xs[i].qs[j] & 0x0F) - 8;
|
||||
let v1 = (xs[i].qs[j] >> 4) - 8;
|
||||
sum_i += v0 * ys[i].qs[j] + v1 * ys[i].qs[j + qk / 2]
|
||||
}
|
||||
sumf += sum_i as f32 * f16::to_f32(xs[i].d) * f16::to_f32(ys[i].d)
|
||||
}
|
||||
Ok(sumf)
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ8_0 {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q8_0;
|
||||
const BLCK_SIZE: usize = QK8_0;
|
||||
type VecDotType = BlockQ8_0;
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
let k = ys.len();
|
||||
if k % QK8_0 != 0 {
|
||||
crate::bail!("dequantize_row_q8_0: {k} is not divisible by {QK8_0}");
|
||||
}
|
||||
|
||||
let nb = k / QK8_0;
|
||||
|
||||
for i in 0..nb {
|
||||
let d = xs[i].d.to_f32();
|
||||
|
||||
for j in 0..QK8_0 {
|
||||
ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
// quantize_row_q8_0
|
||||
let k = ys.len();
|
||||
if k % Self::BLCK_SIZE != 0 {
|
||||
crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE);
|
||||
};
|
||||
let nb = k / Self::BLCK_SIZE;
|
||||
if ys.len() != nb {
|
||||
crate::bail!(
|
||||
"size mismatch {} {} {}",
|
||||
xs.len(),
|
||||
ys.len(),
|
||||
Self::BLCK_SIZE
|
||||
)
|
||||
}
|
||||
for (i, ys) in ys.iter_mut().enumerate() {
|
||||
let mut amax = 0f32;
|
||||
let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
|
||||
for &x in xs.iter() {
|
||||
amax = amax.max(x)
|
||||
}
|
||||
let d = amax / ((1 << 7) - 1) as f32;
|
||||
let id = if d != 0f32 { 1. / d } else { 0. };
|
||||
ys.d = f16::from_f32(d);
|
||||
for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
|
||||
*y = f32::round(x * id) as u8
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn vec_dot(_: usize, _: &[Self], _: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ8_1 {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q3K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8_1;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
|
||||
fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
const BLCK0: usize = 16;
|
||||
const BLCK1: usize = 16;
|
||||
|
||||
// This implementation is in-line with the ggml one and keeps the same variable names.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605
|
||||
pub fn forward_mul_mat<T: GgmlType>(src0: &[T], src1: &[f32], dst: &mut [f32]) -> Result<()> {
|
||||
// TODO: Use the proper sizes here.
|
||||
let (ne00, ne01, ne02, ne03) = (1, 1, 1, 1);
|
||||
let (ne10, ne11, ne12, ne13) = (1, 1, 1, 1);
|
||||
// The strides are in bytes in ggml, however we use the number of elements in candle.
|
||||
let (_, nb1, nb2, nb3) = (1, 1, 1, 1);
|
||||
let (_, nb01, nb02, nb03) = (1, 1, 1, 1);
|
||||
let (_, nb11, nb12, nb13) = (1, 1, 1, 1);
|
||||
|
||||
let nr0 = ne01; // src0 rows
|
||||
let nr1 = ne11 * ne12 * ne13;
|
||||
|
||||
// TODO: Either add multi-threading or remove these bits.
|
||||
let ir010 = 0;
|
||||
let ir011 = nr0;
|
||||
let ir110 = 0;
|
||||
let ir111 = nr1;
|
||||
let r2 = ne12 / ne02;
|
||||
let r3 = ne13 / ne03;
|
||||
|
||||
// TODO: Pre-allocate this.
|
||||
let wdata = &mut [];
|
||||
if ne10 % T::BLCK_SIZE != 0 {
|
||||
crate::bail!(
|
||||
"forward_mul_mat: ne10 {ne10} is not divisible by block size {}",
|
||||
T::BLCK_SIZE
|
||||
)
|
||||
}
|
||||
let row_size = ne10 / T::BLCK_SIZE;
|
||||
for i13 in 0..ne13 {
|
||||
for i12 in 0..ne12 {
|
||||
for i11 in 0..ne11 {
|
||||
let wdata_idx = i11 + i12 * ne11 + i13 * ne11 * ne12;
|
||||
let wdata = &mut wdata[wdata_idx..wdata_idx + row_size];
|
||||
let src1 = &src1[i13 * nb13 + i12 * nb12 + i11 * nb11..];
|
||||
T::VecDotType::from_float(src1, wdata)?
|
||||
}
|
||||
}
|
||||
}
|
||||
for iir1 in (ir110..ir111).step_by(BLCK1) {
|
||||
for iir0 in (ir010..ir011).step_by(BLCK0) {
|
||||
for ir1 in iir1..usize::min(iir1 + BLCK1, ir111) {
|
||||
let i13 = ir1 / (ne12 * ne11);
|
||||
let i12 = (ir1 - i13 * ne12 * ne11) / ne11;
|
||||
let i11 = ir1 - i13 * ne12 * ne11 - i12 * ne11;
|
||||
|
||||
let i03 = i13 / r3;
|
||||
let i02 = i12 / r2;
|
||||
|
||||
let i1 = i11;
|
||||
let i2 = i12;
|
||||
let i3 = i13;
|
||||
|
||||
let src0_row = &src0[i02 * nb02 + i03 * nb03..];
|
||||
let src1_col = &wdata[(i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size..];
|
||||
let dst_col = &mut dst[i1 * nb1 + i2 * nb2 + i3 * nb3..];
|
||||
for ir0 in iir0..usize::min(iir0 + BLCK0, ir011) {
|
||||
let src0_row = &src0_row[ir0 * nb01..];
|
||||
let v = T::vec_dot(ne00, src0_row, src1_col)?;
|
||||
dst_col[ir0 - iir0] += v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
|
||||
@ -528,6 +843,7 @@ pub enum GgmlDType {
|
||||
Q4K,
|
||||
Q5K,
|
||||
Q6K,
|
||||
Q8K,
|
||||
}
|
||||
|
||||
impl GgmlDType {
|
||||
@ -546,6 +862,7 @@ impl GgmlDType {
|
||||
12 => Self::Q4K,
|
||||
13 => Self::Q5K,
|
||||
14 => Self::Q6K,
|
||||
15 => Self::Q8K,
|
||||
_ => crate::bail!("unknown dtype for tensor {u}"),
|
||||
};
|
||||
Ok(dtype)
|
||||
@ -567,6 +884,7 @@ impl GgmlDType {
|
||||
Self::Q4K => std::mem::size_of::<BlockQ4K>(),
|
||||
Self::Q5K => std::mem::size_of::<BlockQ5K>(),
|
||||
Self::Q6K => std::mem::size_of::<BlockQ6K>(),
|
||||
Self::Q8K => std::mem::size_of::<BlockQ8K>(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -580,27 +898,23 @@ impl GgmlDType {
|
||||
Self::Q5_1 => QK5_1,
|
||||
Self::Q8_0 => QK8_0,
|
||||
Self::Q8_1 => QK8_1,
|
||||
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K => QK_K,
|
||||
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => QK_K,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dequantize_and_create_tensor<T, F>(
|
||||
fn dequantize_and_create_tensor<T: GgmlType>(
|
||||
raw_data: &[u8],
|
||||
tensor_elems: usize,
|
||||
size_in_bytes: usize,
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
dequantize_row: F,
|
||||
) -> Result<Tensor>
|
||||
where
|
||||
F: Fn(&[T], &mut [f32]) -> Result<()>,
|
||||
{
|
||||
) -> Result<Tensor> {
|
||||
let mut f32_data = vec![0f32; tensor_elems];
|
||||
let raw_data_ptr = raw_data.as_ptr();
|
||||
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
||||
let raw_data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||
dequantize_row(raw_data, &mut f32_data)?;
|
||||
T::to_float(raw_data, &mut f32_data)?;
|
||||
Tensor::from_vec(f32_data, dims, device)
|
||||
}
|
||||
|
||||
@ -618,87 +932,76 @@ pub fn tensor_from_ggml(
|
||||
let tensor = match ggml_dtype {
|
||||
GgmlDType::F32 => Tensor::from_raw_buffer(raw_data, DType::F32, &dims, device),
|
||||
GgmlDType::F16 => Tensor::from_raw_buffer(raw_data, DType::F16, &dims, device),
|
||||
GgmlDType::Q4_0 => dequantize_and_create_tensor(
|
||||
GgmlDType::Q4_0 => dequantize_and_create_tensor::<BlockQ4_0>(
|
||||
raw_data,
|
||||
tensor_elems,
|
||||
size_in_bytes,
|
||||
dims,
|
||||
device,
|
||||
dequantize_row_q4_0,
|
||||
),
|
||||
GgmlDType::Q4_1 => dequantize_and_create_tensor(
|
||||
GgmlDType::Q4_1 => dequantize_and_create_tensor::<BlockQ4_1>(
|
||||
raw_data,
|
||||
tensor_elems,
|
||||
size_in_bytes,
|
||||
dims,
|
||||
device,
|
||||
dequantize_row_q4_1,
|
||||
),
|
||||
GgmlDType::Q5_0 => dequantize_and_create_tensor(
|
||||
GgmlDType::Q5_0 => dequantize_and_create_tensor::<BlockQ5_0>(
|
||||
raw_data,
|
||||
tensor_elems,
|
||||
size_in_bytes,
|
||||
dims,
|
||||
device,
|
||||
dequantize_row_q5_0,
|
||||
),
|
||||
GgmlDType::Q5_1 => dequantize_and_create_tensor(
|
||||
GgmlDType::Q5_1 => dequantize_and_create_tensor::<BlockQ5_1>(
|
||||
raw_data,
|
||||
tensor_elems,
|
||||
size_in_bytes,
|
||||
dims,
|
||||
device,
|
||||
dequantize_row_q5_1,
|
||||
),
|
||||
GgmlDType::Q8_0 => dequantize_and_create_tensor(
|
||||
GgmlDType::Q8_0 => dequantize_and_create_tensor::<BlockQ8_0>(
|
||||
raw_data,
|
||||
tensor_elems,
|
||||
size_in_bytes,
|
||||
dims,
|
||||
device,
|
||||
dequantize_row_q8_0,
|
||||
),
|
||||
GgmlDType::Q2K => dequantize_and_create_tensor(
|
||||
GgmlDType::Q2K => dequantize_and_create_tensor::<BlockQ2K>(
|
||||
raw_data,
|
||||
tensor_elems,
|
||||
size_in_bytes,
|
||||
dims,
|
||||
device,
|
||||
dequantize_row_q2k,
|
||||
),
|
||||
GgmlDType::Q3K => dequantize_and_create_tensor(
|
||||
GgmlDType::Q3K => dequantize_and_create_tensor::<BlockQ3K>(
|
||||
raw_data,
|
||||
tensor_elems,
|
||||
size_in_bytes,
|
||||
dims,
|
||||
device,
|
||||
dequantize_row_q3k,
|
||||
),
|
||||
GgmlDType::Q4K => dequantize_and_create_tensor(
|
||||
GgmlDType::Q4K => dequantize_and_create_tensor::<BlockQ4K>(
|
||||
raw_data,
|
||||
tensor_elems,
|
||||
size_in_bytes,
|
||||
dims,
|
||||
device,
|
||||
dequantize_row_q4k,
|
||||
),
|
||||
GgmlDType::Q5K => dequantize_and_create_tensor(
|
||||
GgmlDType::Q5K => dequantize_and_create_tensor::<BlockQ5K>(
|
||||
raw_data,
|
||||
tensor_elems,
|
||||
size_in_bytes,
|
||||
dims,
|
||||
device,
|
||||
dequantize_row_q5k,
|
||||
),
|
||||
GgmlDType::Q6K => dequantize_and_create_tensor(
|
||||
GgmlDType::Q6K => dequantize_and_create_tensor::<BlockQ6K>(
|
||||
raw_data,
|
||||
tensor_elems,
|
||||
size_in_bytes,
|
||||
dims,
|
||||
device,
|
||||
dequantize_row_q6k,
|
||||
),
|
||||
|
||||
_ => crate::bail!("quantized type {dtype:?} is not supported yet"),
|
||||
}?;
|
||||
//We only have ggml-quant to f32 conversions, meaning we have to convert to the desired type
|
||||
|
Reference in New Issue
Block a user