Basic quantization support (#453)

* Add a vecdot trait.

* Start implementing mul_mat.

* Add to the mul mat implementation.

* Add q8_0 quantization.

* Implement the GgmlType trait for all types.

* Add the missing block.

* Add a TODO.
Author: Laurent Mazare
Date: 2023-08-15 15:53:19 +01:00 (committed by GitHub)
Parent: ebcfd96d94
Commit: 5e49922be2

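Before the diff, a quick orientation: the commit replaces the free-standing dequantize_row_* functions with a GgmlType trait (to_float, from_float, vec_dot) implemented per block type, plus a first generic forward_mul_mat. The short sketch below is not part of the commit; it is a hedged illustration of what calling the new to_float on a hand-built BlockQ8_0 block could look like, assuming the snippet lives inside the same module as the code below (the block fields are private) and reuses that module's Result and f16 imports.

// Not from the commit: a minimal usage sketch, assuming in-module access to
// BlockQ8_0, QK8_0, and the module's `Result` / `f16` imports.
fn dequantize_one_q8_0_block() -> Result<()> {
    // A Q8_0 block is a shared f16 scale `d` plus QK8_0 (= 32) quantized values.
    let mut qs = [0u8; QK8_0];
    for (j, q) in qs.iter_mut().enumerate() {
        *q = j as u8;
    }
    let block = BlockQ8_0 { d: f16::from_f32(0.5), qs };

    // `to_float` expects the output length to be a multiple of the block size.
    let mut ys = vec![0f32; QK8_0];
    BlockQ8_0::to_float(std::slice::from_ref(&block), &mut ys)?;

    // Each output is qs[j] as f32 * d, so ys[4] == 4.0 * 0.5 == 2.0 here.
    assert_eq!(ys[4], 2.0);
    Ok(())
}

The same trait shape is what lets dequantize_and_create_tensor further down become generic over T: GgmlType instead of taking a dequantize_row closure.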

@@ -15,15 +15,27 @@ pub const QK5_1: usize = 32;
pub const QK8_0: usize = 32;
pub const QK8_1: usize = 32;
pub trait GgmlType: Sized {
const DTYPE: GgmlDType;
const BLCK_SIZE: usize;
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>;
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>;
type VecDotType: GgmlType;
// Dot product used as a building block for quantized mat-mul.
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
}
#[repr(C)]
struct BlockQ4_0 {
pub struct BlockQ4_0 {
d: f16,
qs: [u8; QK4_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
#[repr(C)]
struct BlockQ4_1 {
pub struct BlockQ4_1 {
d: f16,
m: f16,
qs: [u8; QK4_1 / 2],
@@ -31,7 +43,7 @@ struct BlockQ4_1 {
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
#[repr(C)]
struct BlockQ5_0 {
pub struct BlockQ5_0 {
d: f16,
qh: [u8; 4],
qs: [u8; QK5_0 / 2],
@@ -39,7 +51,7 @@ struct BlockQ5_0 {
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
#[repr(C)]
struct BlockQ5_1 {
pub struct BlockQ5_1 {
d: f16,
m: f16,
qh: [u8; 4],
@@ -48,14 +60,14 @@ struct BlockQ5_1 {
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
#[repr(C)]
struct BlockQ8_0 {
pub struct BlockQ8_0 {
d: f16,
qs: [u8; QK8_0],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
#[repr(C)]
struct BlockQ8_1 {
pub struct BlockQ8_1 {
d: f16,
s: f16,
qs: [u8; QK8_1],
@@ -63,7 +75,7 @@ struct BlockQ8_1 {
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
#[repr(C)]
struct BlockQ2K {
pub struct BlockQ2K {
scales: [u8; QK_K / 16],
qs: [u8; QK_K / 4],
d: f16,
@@ -72,7 +84,7 @@ struct BlockQ2K {
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
#[repr(C)]
struct BlockQ3K {
pub struct BlockQ3K {
hmask: [u8; QK_K / 8],
qs: [u8; QK_K / 4],
scales: [u8; 12],
@@ -82,7 +94,7 @@ const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
#[repr(C)]
struct BlockQ4K {
pub struct BlockQ4K {
d: f16,
dmin: f16,
scales: [u8; K_SCALE_SIZE],
@@ -91,7 +103,7 @@ struct BlockQ4K {
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
#[repr(C)]
struct BlockQ5K {
pub struct BlockQ5K {
d: f16,
dmin: f16,
scales: [u8; K_SCALE_SIZE],
@@ -102,7 +114,7 @@ const _: () =
assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
#[repr(C)]
struct BlockQ6K {
pub struct BlockQ6K {
ql: [u8; QK_K / 2],
qh: [u8; QK_K / 4],
scales: [i8; QK_K / 16],
@@ -110,30 +122,29 @@ struct BlockQ6K {
}
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525
fn dequantize_row_q4_0(xs: &[BlockQ4_0], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK4_0 != 0 {
crate::bail!("dequantize_row_q4_0: {k} is not divisible by {QK4_0}")
}
let nb = k / QK4_0;
for i in 0..nb {
let d = xs[i].d.to_f32();
for j in 0..(QK4_0 / 2) {
let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8;
let x1 = (xs[i].qs[j] >> 4) as i16 - 8;
ys[i * QK4_0 + j] = (x0 as f32) * d;
ys[i * QK4_0 + j + QK4_0 / 2] = (x1 as f32) * d;
}
}
Ok(())
#[repr(C)]
pub struct BlockQ8K {
d: f32,
qs: [i8; QK_K],
bsums: [i16; QK_K / 16],
}
const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::<BlockQ8K>());
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1545
fn dequantize_row_q4_1(xs: &[BlockQ4_1], ys: &mut [f32]) -> Result<()> {
impl GgmlType for BlockQ4_1 {
const DTYPE: GgmlDType = GgmlDType::Q4_1;
const BLCK_SIZE: usize = QK4_1;
type VecDotType = BlockQ8_1;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1545
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK4_1 != 0 {
crate::bail!("dequantize_row_q4_1: {k} is not divisible by {QK4_1}");
@@ -153,10 +164,24 @@ fn dequantize_row_q4_1(xs: &[BlockQ4_1], ys: &mut [f32]) -> Result<()> {
}
}
Ok(())
}
}
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1566
fn dequantize_row_q5_0(xs: &[BlockQ5_0], ys: &mut [f32]) -> Result<()> {
impl GgmlType for BlockQ5_0 {
const DTYPE: GgmlDType = GgmlDType::Q5_0;
const BLCK_SIZE: usize = QK5_0;
type VecDotType = BlockQ8_0;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1566
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK5_0 != 0 {
crate::bail!("dequantize_row_q5_0: {k} is not divisible by {QK5_0}");
@@ -179,10 +204,24 @@ fn dequantize_row_q5_0(xs: &[BlockQ5_0], ys: &mut [f32]) -> Result<()> {
}
}
Ok(())
}
}
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1592
fn dequantize_row_q5_1(xs: &[BlockQ5_1], ys: &mut [f32]) -> Result<()> {
impl GgmlType for BlockQ5_1 {
const DTYPE: GgmlDType = GgmlDType::Q5_1;
const BLCK_SIZE: usize = QK5_1;
type VecDotType = BlockQ8_1;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1592
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK5_1 != 0 {
crate::bail!("dequantize_row_q5_1: {k} is not divisible by {QK5_1}");
@@ -206,30 +245,23 @@ fn dequantize_row_q5_1(xs: &[BlockQ5_1], ys: &mut [f32]) -> Result<()> {
}
}
Ok(())
}
}
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619
fn dequantize_row_q8_0(vx: &[BlockQ8_0], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK8_0 != 0 {
crate::bail!("dequantize_row_q8_0: {k} is not divisible by {QK8_0}");
impl GgmlType for BlockQ2K {
const DTYPE: GgmlDType = GgmlDType::Q2K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
let nb = k / QK8_0;
let xs: &[BlockQ8_0] = unsafe { std::mem::transmute(vx) };
for i in 0..nb {
let d = xs[i].d.to_f32();
for j in 0..QK8_0 {
ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
}
Ok(())
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
fn dequantize_row_q2k(xs: &[BlockQ2K], ys: &mut [f32]) -> Result<()> {
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
@@ -271,6 +303,7 @@ fn dequantize_row_q2k(xs: &[BlockQ2K], ys: &mut [f32]) -> Result<()> {
}
}
Ok(())
}
}
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
@@ -284,8 +317,21 @@ fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
(d, m)
}
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
impl GgmlType for BlockQ4K {
const DTYPE: GgmlDType = GgmlDType::Q4K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}")
@@ -318,15 +364,42 @@ fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
}
}
Ok(())
}
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
fn dequantize_row_q3k(_xs: &[BlockQ3K], _ys: &mut [f32]) -> Result<()> {
impl GgmlType for BlockQ3K {
const DTYPE: GgmlDType = GgmlDType::Q3K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
todo!()
}
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928
fn dequantize_row_q5k(xs: &[BlockQ5K], ys: &mut [f32]) -> Result<()> {
impl GgmlType for BlockQ5K {
const DTYPE: GgmlDType = GgmlDType::Q5K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q5k: {k} is not divisible by {QK_K}")
@@ -366,10 +439,24 @@ fn dequantize_row_q5k(xs: &[BlockQ5K], ys: &mut [f32]) -> Result<()> {
}
}
Ok(())
}
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> {
impl GgmlType for BlockQ6K {
const DTYPE: GgmlDType = GgmlDType::Q6K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
@@ -399,6 +486,234 @@ fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> {
}
}
Ok(())
}
}
impl GgmlType for BlockQ8K {
const DTYPE: GgmlDType = GgmlDType::Q8K;
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
todo!()
}
}
impl GgmlType for BlockQ4_0 {
const DTYPE: GgmlDType = GgmlDType::Q4_0;
const BLCK_SIZE: usize = QK4_0;
type VecDotType = BlockQ8_0;
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK4_0 != 0 {
crate::bail!("dequantize_row_q4_0: {k} is not divisible by {QK4_0}")
}
let nb = k / QK4_0;
for i in 0..nb {
let d = xs[i].d.to_f32();
for j in 0..(QK4_0 / 2) {
let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8;
let x1 = (xs[i].qs[j] >> 4) as i16 - 8;
ys[i * QK4_0 + j] = (x0 as f32) * d;
ys[i * QK4_0 + j + QK4_0 / 2] = (x1 as f32) * d;
}
}
Ok(())
}
fn from_float(_: &[f32], _: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
// Generic implementation.
let mut sumf = 0f32;
for i in 0..nb {
let mut sum_i = 0;
for j in 0..qk / 2 {
let v0 = (xs[i].qs[j] & 0x0F) - 8;
let v1 = (xs[i].qs[j] >> 4) - 8;
sum_i += v0 * ys[i].qs[j] + v1 * ys[i].qs[j + qk / 2]
}
sumf += sum_i as f32 * f16::to_f32(xs[i].d) * f16::to_f32(ys[i].d)
}
Ok(sumf)
}
}
impl GgmlType for BlockQ8_0 {
const DTYPE: GgmlDType = GgmlDType::Q8_0;
const BLCK_SIZE: usize = QK8_0;
type VecDotType = BlockQ8_0;
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK8_0 != 0 {
crate::bail!("dequantize_row_q8_0: {k} is not divisible by {QK8_0}");
}
let nb = k / QK8_0;
for i in 0..nb {
let d = xs[i].d.to_f32();
for j in 0..QK8_0 {
ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
}
}
Ok(())
}
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
// quantize_row_q8_0
let k = ys.len();
if k % Self::BLCK_SIZE != 0 {
crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE);
};
let nb = k / Self::BLCK_SIZE;
if ys.len() != nb {
crate::bail!(
"size mismatch {} {} {}",
xs.len(),
ys.len(),
Self::BLCK_SIZE
)
}
for (i, ys) in ys.iter_mut().enumerate() {
let mut amax = 0f32;
let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
for &x in xs.iter() {
amax = amax.max(x)
}
let d = amax / ((1 << 7) - 1) as f32;
let id = if d != 0f32 { 1. / d } else { 0. };
ys.d = f16::from_f32(d);
for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
*y = f32::round(x * id) as u8
}
}
Ok(())
}
fn vec_dot(_: usize, _: &[Self], _: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
}
impl GgmlType for BlockQ8_1 {
const DTYPE: GgmlDType = GgmlDType::Q8_1;
const BLCK_SIZE: usize = QK8_1;
type VecDotType = BlockQ8_1;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
todo!()
}
}
const BLCK0: usize = 16;
const BLCK1: usize = 16;
// This implementation is in-line with the ggml one and keeps the same variable names.
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605
pub fn forward_mul_mat<T: GgmlType>(src0: &[T], src1: &[f32], dst: &mut [f32]) -> Result<()> {
// TODO: Use the proper sizes here.
let (ne00, ne01, ne02, ne03) = (1, 1, 1, 1);
let (ne10, ne11, ne12, ne13) = (1, 1, 1, 1);
// The strides are in bytes in ggml, however we use the number of elements in candle.
let (_, nb1, nb2, nb3) = (1, 1, 1, 1);
let (_, nb01, nb02, nb03) = (1, 1, 1, 1);
let (_, nb11, nb12, nb13) = (1, 1, 1, 1);
let nr0 = ne01; // src0 rows
let nr1 = ne11 * ne12 * ne13;
// TODO: Either add multi-threading or remove these bits.
let ir010 = 0;
let ir011 = nr0;
let ir110 = 0;
let ir111 = nr1;
let r2 = ne12 / ne02;
let r3 = ne13 / ne03;
// TODO: Pre-allocate this.
let wdata = &mut [];
if ne10 % T::BLCK_SIZE != 0 {
crate::bail!(
"forward_mul_mat: ne10 {ne10} is not divisible by block size {}",
T::BLCK_SIZE
)
}
let row_size = ne10 / T::BLCK_SIZE;
for i13 in 0..ne13 {
for i12 in 0..ne12 {
for i11 in 0..ne11 {
let wdata_idx = i11 + i12 * ne11 + i13 * ne11 * ne12;
let wdata = &mut wdata[wdata_idx..wdata_idx + row_size];
let src1 = &src1[i13 * nb13 + i12 * nb12 + i11 * nb11..];
T::VecDotType::from_float(src1, wdata)?
}
}
}
for iir1 in (ir110..ir111).step_by(BLCK1) {
for iir0 in (ir010..ir011).step_by(BLCK0) {
for ir1 in iir1..usize::min(iir1 + BLCK1, ir111) {
let i13 = ir1 / (ne12 * ne11);
let i12 = (ir1 - i13 * ne12 * ne11) / ne11;
let i11 = ir1 - i13 * ne12 * ne11 - i12 * ne11;
let i03 = i13 / r3;
let i02 = i12 / r2;
let i1 = i11;
let i2 = i12;
let i3 = i13;
let src0_row = &src0[i02 * nb02 + i03 * nb03..];
let src1_col = &wdata[(i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size..];
let dst_col = &mut dst[i1 * nb1 + i2 * nb2 + i3 * nb3..];
for ir0 in iir0..usize::min(iir0 + BLCK0, ir011) {
let src0_row = &src0_row[ir0 * nb01..];
let v = T::vec_dot(ne00, src0_row, src1_col)?;
dst_col[ir0 - iir0] += v
}
}
}
}
Ok(())
}
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
@@ -528,6 +843,7 @@ pub enum GgmlDType {
Q4K,
Q5K,
Q6K,
Q8K,
}
impl GgmlDType {
@@ -546,6 +862,7 @@ impl GgmlDType {
12 => Self::Q4K,
13 => Self::Q5K,
14 => Self::Q6K,
15 => Self::Q8K,
_ => crate::bail!("unknown dtype for tensor {u}"),
};
Ok(dtype)
@@ -567,6 +884,7 @@ impl GgmlDType {
Self::Q4K => std::mem::size_of::<BlockQ4K>(),
Self::Q5K => std::mem::size_of::<BlockQ5K>(),
Self::Q6K => std::mem::size_of::<BlockQ6K>(),
Self::Q8K => std::mem::size_of::<BlockQ8K>(),
}
}
@@ -580,27 +898,23 @@ impl GgmlDType {
Self::Q5_1 => QK5_1,
Self::Q8_0 => QK8_0,
Self::Q8_1 => QK8_1,
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K => QK_K,
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => QK_K,
}
}
}
fn dequantize_and_create_tensor<T, F>(
fn dequantize_and_create_tensor<T: GgmlType>(
raw_data: &[u8],
tensor_elems: usize,
size_in_bytes: usize,
dims: Vec<usize>,
device: &Device,
dequantize_row: F,
) -> Result<Tensor>
where
F: Fn(&[T], &mut [f32]) -> Result<()>,
{
) -> Result<Tensor> {
let mut f32_data = vec![0f32; tensor_elems];
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let raw_data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
dequantize_row(raw_data, &mut f32_data)?;
T::to_float(raw_data, &mut f32_data)?;
Tensor::from_vec(f32_data, dims, device)
}
@@ -618,87 +932,76 @@ pub fn tensor_from_ggml(
let tensor = match ggml_dtype {
GgmlDType::F32 => Tensor::from_raw_buffer(raw_data, DType::F32, &dims, device),
GgmlDType::F16 => Tensor::from_raw_buffer(raw_data, DType::F16, &dims, device),
GgmlDType::Q4_0 => dequantize_and_create_tensor(
GgmlDType::Q4_0 => dequantize_and_create_tensor::<BlockQ4_0>(
raw_data,
tensor_elems,
size_in_bytes,
dims,
device,
dequantize_row_q4_0,
),
GgmlDType::Q4_1 => dequantize_and_create_tensor(
GgmlDType::Q4_1 => dequantize_and_create_tensor::<BlockQ4_1>(
raw_data,
tensor_elems,
size_in_bytes,
dims,
device,
dequantize_row_q4_1,
),
GgmlDType::Q5_0 => dequantize_and_create_tensor(
GgmlDType::Q5_0 => dequantize_and_create_tensor::<BlockQ5_0>(
raw_data,
tensor_elems,
size_in_bytes,
dims,
device,
dequantize_row_q5_0,
),
GgmlDType::Q5_1 => dequantize_and_create_tensor(
GgmlDType::Q5_1 => dequantize_and_create_tensor::<BlockQ5_1>(
raw_data,
tensor_elems,
size_in_bytes,
dims,
device,
dequantize_row_q5_1,
),
GgmlDType::Q8_0 => dequantize_and_create_tensor(
GgmlDType::Q8_0 => dequantize_and_create_tensor::<BlockQ8_0>(
raw_data,
tensor_elems,
size_in_bytes,
dims,
device,
dequantize_row_q8_0,
),
GgmlDType::Q2K => dequantize_and_create_tensor(
GgmlDType::Q2K => dequantize_and_create_tensor::<BlockQ2K>(
raw_data,
tensor_elems,
size_in_bytes,
dims,
device,
dequantize_row_q2k,
),
GgmlDType::Q3K => dequantize_and_create_tensor(
GgmlDType::Q3K => dequantize_and_create_tensor::<BlockQ3K>(
raw_data,
tensor_elems,
size_in_bytes,
dims,
device,
dequantize_row_q3k,
),
GgmlDType::Q4K => dequantize_and_create_tensor(
GgmlDType::Q4K => dequantize_and_create_tensor::<BlockQ4K>(
raw_data,
tensor_elems,
size_in_bytes,
dims,
device,
dequantize_row_q4k,
),
GgmlDType::Q5K => dequantize_and_create_tensor(
GgmlDType::Q5K => dequantize_and_create_tensor::<BlockQ5K>(
raw_data,
tensor_elems,
size_in_bytes,
dims,
device,
dequantize_row_q5k,
),
GgmlDType::Q6K => dequantize_and_create_tensor(
GgmlDType::Q6K => dequantize_and_create_tensor::<BlockQ6K>(
raw_data,
tensor_elems,
size_in_bytes,
dims,
device,
dequantize_row_q6k,
),
_ => crate::bail!("quantized type {dtype:?} is not supported yet"),
}?;
//We only have ggml-quant to f32 conversions, meaning we have to convert to the desired type