Add reference implementation for q4k and q5k (#586)

* add `q2k` vec-dot

* `q3k` vec-dot + quantization bugfix

* `q4k` vec-dot

* `q5k` vec-dot

* Validate against GGML unit test results.

* Remove some more `transmutes`
This commit is contained in:
Lukas Kreussel
2023-08-26 13:07:54 +02:00
committed by GitHub
parent 864227edbf
commit c72eb3d75b
2 changed files with 270 additions and 5 deletions

View File

@ -906,8 +906,91 @@ impl GgmlType for BlockQ4K {
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
let mut utmp: [u32; 4] = [0; 4];
let mut scales: [u8; 8];
let mut mins: [u8; 8];
let mut aux8: [i8; QK_K] = [0; QK_K];
let mut aux16: [i16; 8] = [0; 8];
let mut sums: [f32; 8] = [0.0; 8];
let mut aux32: [i32; 8] = [0; 8];
let mut sumf = 0.0;
for (y, x) in ys.iter().zip(xs.iter()) {
let q4 = &x.qs;
let q8 = &y.qs;
aux32.fill(0);
let mut a = &mut aux8[..];
let mut q4 = &q4[..];
for _ in 0..QK_K / 64 {
for l in 0..32 {
a[l] = (q4[l] & 0xF) as i8;
}
a = &mut a[32..];
for l in 0..32 {
a[l] = (q4[l] >> 4) as i8;
}
a = &mut a[32..];
q4 = &q4[32..];
}
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
//extract scales and mins
let mut utemp_scales = &mut [0u32; 2];
let mut utemp_mins = &mut [0u32; 2];
utemp_scales.copy_from_slice(&utmp[0..2]);
utemp_mins.copy_from_slice(&utmp[2..4]);
scales =
unsafe { *std::mem::transmute::<&mut [u32; 2], &mut [u8; 8]>(&mut utemp_scales) };
mins = unsafe { *std::mem::transmute::<&mut [u32; 2], &mut [u8; 8]>(&mut utemp_mins) };
let mut sumi = 0;
for j in 0..QK_K / 16 {
sumi += y.bsums[j] as i32 * mins[j / 2] as i32;
}
let mut a = &mut aux8[..];
let mut q8 = &q8[..];
for scale in scales {
let scale = scale as i32;
for _ in 0..4 {
for l in 0..8 {
aux16[l] = q8[l] as i16 * a[l] as i16;
}
for l in 0..8 {
aux32[l] += scale * aux16[l] as i32;
}
q8 = &q8[8..];
a = &mut a[8..];
}
}
let d = x.d.to_f32() * y.d;
for l in 0..8 {
sums[l] += d * aux32[l] as f32;
}
let dmin = x.dmin.to_f32() * y.d;
sumf -= dmin * sumi as f32;
}
Ok(sumf + sums.iter().sum::<f32>())
}
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
@ -1008,8 +1091,98 @@ impl GgmlType for BlockQ5K {
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
todo!()
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
}
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
let mut utmp: [u32; 4] = [0; 4];
let mut scales: [u8; 8];
let mut mins: [u8; 8];
let mut aux8: [i8; QK_K] = [0; QK_K];
let mut aux16: [i16; 8] = [0; 8];
let mut sums: [f32; 8] = [0.0; 8];
let mut aux32: [i32; 8] = [0; 8];
let mut sumf = 0.0;
for (y, x) in ys.iter().zip(xs.iter()) {
let q5 = &x.qs;
let hm = &x.qh;
let q8 = &y.qs;
aux32.fill(0);
let mut a = &mut aux8[..];
let mut q5 = &q5[..];
let mut m = 1u8;
for _ in 0..QK_K / 64 {
for l in 0..32 {
a[l] = (q5[l] & 0xF) as i8;
a[l] += if hm[l] & m != 0 { 16 } else { 0 };
}
a = &mut a[32..];
m <<= 1;
for l in 0..32 {
a[l] = (q5[l] >> 4) as i8;
a[l] += if hm[l] & m != 0 { 16 } else { 0 };
}
a = &mut a[32..];
m <<= 1;
q5 = &q5[32..];
}
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
//extract scales and mins
let mut utemp_scales = &mut [0u32; 2];
let mut utemp_mins = &mut [0u32; 2];
utemp_scales.copy_from_slice(&utmp[0..2]);
utemp_mins.copy_from_slice(&utmp[2..4]);
scales =
unsafe { *std::mem::transmute::<&mut [u32; 2], &mut [u8; 8]>(&mut utemp_scales) };
mins = unsafe { *std::mem::transmute::<&mut [u32; 2], &mut [u8; 8]>(&mut utemp_mins) };
let mut sumi = 0;
for j in 0..QK_K / 16 {
sumi += y.bsums[j] as i32 * mins[j / 2] as i32;
}
let mut a = &mut aux8[..];
let mut q8 = &q8[..];
for scale in scales {
let scale = scale as i32;
for _ in 0..4 {
for l in 0..8 {
aux16[l] = q8[l] as i16 * a[l] as i16;
}
for l in 0..8 {
aux32[l] += scale * aux16[l] as i32;
}
q8 = &q8[8..];
a = &mut a[8..];
}
}
let d = x.d.to_f32() * y.d;
for l in 0..8 {
sums[l] += d * aux32[l] as f32;
}
let dmin = x.dmin.to_f32() * y.d;
sumf -= dmin * sumi as f32;
}
Ok(sumf + sums.iter().sum::<f32>())
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L793