Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Reference implementations of q2k and q3k vec-dot functions (#580)

* add `q2k` vec-dot
* `q3k` vec-dot + quantization bugfix
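For context, a minimal usage sketch of how the new kernels are exercised end to end, mirroring the quantized_matmul_q2k test added at the bottom of this diff. The candle_core import paths, the Tensor::randn initialisation and the (m, k) activation / (n, k) weight shape convention are illustrative assumptions, not taken from the diff itself.

use candle_core::quantized::{self, k_quants::BlockQ2K};
use candle_core::{Device, Result, Tensor};

fn q2k_matmul_sketch() -> Result<Tensor> {
    let dev = &Device::Cpu;
    let lhs = Tensor::randn(0f32, 1.0, (11, 512), dev)?; // activations, assumed (m, k)
    let rhs = Tensor::randn(0f32, 1.0, (21, 512), dev)?; // weights, assumed (n, k), to be quantized
    // Quantization goes through BlockQ2K::from_float; the matmul below ends up in the
    // new BlockQ2K::vec_dot, whose second operand type is BlockQ8K (see VecDotType).
    let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
    let mm = quantized::QMatMul::from_qtensor(rhs);
    mm.forward(&lhs) // expected shape: (11, 21)
}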
@@ -450,8 +450,56 @@ impl GgmlType for BlockQ2K {
     const BLCK_SIZE: usize = QK_K;
     type VecDotType = BlockQ8K;
 
-    fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
-        todo!()
+    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        if n % QK_K != 0 {
+            crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
+        }
+
+        let mut sumf = 0.0;
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let mut q2: &[_] = &x.qs;
+            let mut q8: &[_] = &y.qs;
+            let sc = &x.scales;
+
+            let mut summs = 0;
+            for (bsum, scale) in y.bsums.iter().zip(sc) {
+                summs += *bsum as i32 * ((scale >> 4) as i32);
+            }
+
+            let dall = y.d * x.d.to_f32();
+            let dmin = y.d * x.dmin.to_f32();
+
+            let mut isum = 0;
+            let mut is = 0;
+            let mut d;
+            for _ in 0..(QK_K / 128) {
+                let mut shift = 0;
+                for _ in 0..4 {
+                    d = (sc[is] & 0xF) as i32;
+                    is += 1;
+                    let mut isuml = 0;
+                    for l in 0..16 {
+                        isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
+                    }
+                    isum += d * isuml;
+                    d = (sc[is] & 0xF) as i32;
+                    is += 1;
+                    isuml = 0;
+                    for l in 16..32 {
+                        isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
+                    }
+                    isum += d * isuml;
+                    shift += 2;
+                    // adjust the indexing
+                    q8 = &q8[32..];
+                }
+                // adjust the indexing
+                q2 = &q2[32..];
+            }
+            sumf += dall * isum as f32 - dmin * summs as f32;
+        }
+
+        Ok(sumf)
     }
 
     // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L279
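Each byte of x.scales in a Q2K block packs two 4-bit fields, which is why the loop above reads both sc[is] & 0xF (the sub-block scale, weighted by dall) and scale >> 4 (the sub-block minimum, weighted by dmin against y.bsums), accumulating sumf += dall * isum - dmin * summs per superblock. A purely illustrative helper, not part of the diff, spelling out that split:

// Illustrative sketch only (not in the diff): how a Q2K scale byte is split.
fn unpack_q2k_scale(byte: u8) -> (i32, i32) {
    let scale = (byte & 0xF) as i32; // weights the 2-bit quant dot products (dall term)
    let min = (byte >> 4) as i32;    // weights the Q8K block sums (dmin term)
    (scale, min)
}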
@@ -565,8 +613,129 @@ impl GgmlType for BlockQ3K {
     const BLCK_SIZE: usize = QK_K;
     type VecDotType = BlockQ8K;
 
-    fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
-        todo!()
+    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        if n % QK_K != 0 {
+            crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
+        }
+
+        const KMASK1: u32 = 0x03030303;
+        const KMASK2: u32 = 0x0f0f0f0f;
+
+        let mut aux8: [i8; QK_K] = [0; QK_K];
+        let mut aux16: [i16; 8] = [0; 8];
+        let mut sums: [f32; 8] = [0.0; 8];
+        let mut aux32: [i32; 8] = [0; 8];
+
+        let mut auxs: [u32; 4] = [0; 4];
+        let mut scales: &[i8; 16];
+
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let mut q3: &[u8] = &x.qs;
+            let hmask: &[u8] = &x.hmask;
+            let mut q8: &[i8] = &y.qs;
+
+            aux32.iter_mut().for_each(|x| *x = 0);
+            let mut a = &mut aux8[..];
+
+            let mut m = 1;
+            //Like the GGML original this is written this way to enable the compiler to vectorize it.
+            for _ in 0..QK_K / 128 {
+                a.iter_mut()
+                    .take(32)
+                    .zip(q3)
+                    .for_each(|(a_val, q3_val)| *a_val = (q3_val & 3) as i8);
+                a.iter_mut()
+                    .take(32)
+                    .zip(hmask)
+                    .for_each(|(a_val, hmask_val)| {
+                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
+                    });
+                a = &mut a[32..];
+                m <<= 1;
+
+                a.iter_mut()
+                    .take(32)
+                    .zip(q3)
+                    .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 2) & 3) as i8);
+                a.iter_mut()
+                    .take(32)
+                    .zip(hmask)
+                    .for_each(|(a_val, hmask_val)| {
+                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
+                    });
+                a = &mut a[32..];
+                m <<= 1;
+
+                a.iter_mut()
+                    .take(32)
+                    .zip(q3)
+                    .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 4) & 3) as i8);
+                a.iter_mut()
+                    .take(32)
+                    .zip(hmask)
+                    .for_each(|(a_val, hmask_val)| {
+                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
+                    });
+                a = &mut a[32..];
+                m <<= 1;
+
+                a.iter_mut()
+                    .take(32)
+                    .zip(q3)
+                    .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 6) & 3) as i8);
+                a.iter_mut()
+                    .take(32)
+                    .zip(hmask)
+                    .for_each(|(a_val, hmask_val)| {
+                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
+                    });
+                a = &mut a[32..];
+                m <<= 1;
+                q3 = &q3[32..];
+            }
+
+            a = &mut aux8[..];
+
+            let aux_raw = unsafe {
+                std::mem::transmute::<&mut [u8; 12], &mut [u32; 3]>(&mut x.scales.clone())
+            };
+            auxs[0..3].copy_from_slice(aux_raw);
+
+            let tmp = auxs[2];
+            auxs[2] = ((auxs[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4);
+            auxs[3] = ((auxs[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4);
+            auxs[0] = (auxs[0] & KMASK2) | (((tmp) & KMASK1) << 4);
+            auxs[1] = (auxs[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4);
+
+            scales = unsafe { std::mem::transmute::<&mut [u32; 4], &mut [i8; 16]>(&mut auxs) };
+
+            for scale in scales {
+                for l in 0..8 {
+                    aux16[l] = q8[l] as i16 * a[l] as i16;
+                }
+                for l in 0..8 {
+                    aux32[l] += (*scale as i32 - 32) * aux16[l] as i32;
+                }
+                q8 = &q8[8..];
+                a = &mut a[8..];
+
+                for l in 0..8 {
+                    aux16[l] = q8[l] as i16 * a[l] as i16;
+                }
+                for l in 0..8 {
+                    aux32[l] += (*scale as i32 - 32) * aux16[l] as i32;
+                }
+                q8 = &q8[8..];
+                a = &mut a[8..];
+            }
+
+            let d = x.d.to_f32() * y.d;
+            for l in 0..8 {
+                sums[l] += d * aux32[l] as f32;
+            }
+        }
+
+        Ok(sums.iter().sum())
     }
 
     fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
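The unrolled loop above rebuilds each signed 3-bit Q3K value from two sources: the low two bits come from x.qs at shifts 0, 2, 4 and 6, and the missing high bit comes from x.hmask, with a cleared mask bit lowering the value by 4, so the decoded values span -4..=3. A compact, purely illustrative restatement of that decode with a hypothetical helper name:

// Illustrative sketch only (not in the diff): the per-value decode performed by the
// unrolled loop above. `q` is a byte of `qs`, `shift` is 0, 2, 4 or 6, `h` is a byte
// of `hmask` and `m` is the current high-bit mask (1, 2, 4, ...).
fn decode_q3(q: u8, shift: u32, h: u8, m: u8) -> i8 {
    let low = ((q >> shift) & 3) as i8;
    if h & m != 0 { low } else { low - 4 }
}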
@@ -577,9 +746,12 @@ impl GgmlType for BlockQ3K {
             }
 
             // Get max scale by absolute value.
-            let max_scale = scales
-                .iter()
-                .fold(0.0, |max, &val| if val.abs() > max { val } else { max });
+            let mut max_scale: f32 = 0.0;
+            for &scale in scales.iter() {
+                if scale.abs() > max_scale.abs() {
+                    max_scale = scale;
+                }
+            }
 
             block.scales.fill(0);
 
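This is the quantization bugfix from the commit message: the old fold compared val.abs() against the signed running value, so once a negative scale had been selected, any later value, however small in magnitude, would replace it; the new loop keeps the signed scale with the largest absolute value. For reference, an equivalent iterator form of the fix (a sketch with a hypothetical helper name, not what the commit uses):

// Equivalent to the fixed loop above: keep the signed value whose absolute value
// is largest, starting from 0.0.
fn max_by_abs(scales: &[f32]) -> f32 {
    scales
        .iter()
        .copied()
        .fold(0.0f32, |max, v| if v.abs() > max.abs() { v } else { max })
}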
@@ -444,6 +444,60 @@ fn get_random_tensors(
     Ok((lhs, rhs, mm))
 }
 
+#[test]
+fn quantized_matmul_q2k() -> Result<()> {
+    use k_quants::BlockQ2K;
+
+    let cpu = &Device::Cpu;
+    let (m, k, n) = (11, 512, 21);
+    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
+
+    let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
+    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let mm = rhs.forward(&lhs)?;
+
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [0.916, 0.422, 0.215, 1.668]);
+
+    //mirrored GGML unit test
+    ggml_matmul_error_test::<BlockQ2K>()?;
+
+    Ok(())
+}
+
+#[test]
+fn quantized_matmul_q3k() -> Result<()> {
+    use k_quants::BlockQ3K;
+
+    let cpu = &Device::Cpu;
+    let (m, k, n) = (11, 512, 21);
+    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
+
+    let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
+    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let mm = rhs.forward(&lhs)?;
+
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);
+
+    //mirrored GGML unit test
+    ggml_matmul_error_test::<BlockQ3K>()?;
+
+    Ok(())
+}
+
 #[test]
 fn quantized_matmul_q6k() -> Result<()> {
     use k_quants::BlockQ6K;