Simd128 q2k vecdot (#982)

* Sketch the simd128 version of q2k vecdot.

* Use a single accumulator.
This commit is contained in:
Laurent Mazare
2023-09-28 12:16:35 +01:00
committed by GitHub
parent 5e1c595e00
commit 25657804ef
2 changed files with 57 additions and 4 deletions

View File

@ -687,6 +687,9 @@ impl GgmlType for BlockQ2K {
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
return super::neon::vec_dot_q2k_q8k(n, xs, ys); return super::neon::vec_dot_q2k_q8k(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q2k_q8k(n, xs, ys);
if n % QK_K != 0 { if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
} }

View File

@ -1,4 +1,4 @@
use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ8K, BlockQ8_0, QK8_0, QK_K}; use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result; use crate::Result;
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
use half::f16; use half::f16;
@ -99,6 +99,58 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
} }
} }
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0.0;
for (x, y) in xs.iter().zip(ys.iter()) {
let mut q2: &[_] = &x.qs;
let mut q8: &[_] = &y.qs;
let sc = &x.scales;
let mut summs = 0;
for (bsum, scale) in y.bsums.iter().zip(sc) {
summs += *bsum as i32 * ((scale >> 4) as i32);
}
let dall = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();
let mut isum = 0;
let mut is = 0;
let mut d;
for _ in 0..(QK_K / 128) {
let mut shift = 0;
for _ in 0..4 {
d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = 0;
for l in 0..16 {
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
}
isum += d * isuml;
d = (sc[is] & 0xF) as i32;
is += 1;
isuml = 0;
for l in 16..32 {
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
}
isum += d * isuml;
shift += 2;
// adjust the indexing
q8 = &q8[32..];
}
// adjust the indexing
q2 = &q2[32..];
}
sumf += dall * isum as f32 - dmin * summs as f32;
}
Ok(sumf)
}
#[inline(always)] #[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> { pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 { if n % QK_K != 0 {
@ -115,7 +167,6 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
let mut aux8: [u8; QK_K] = [0; QK_K]; let mut aux8: [u8; QK_K] = [0; QK_K];
let mut sums = f32x4_splat(0f32); let mut sums = f32x4_splat(0f32);
let mut sumf = f32x4_splat(0f32);
unsafe { unsafe {
for (y, x) in ys.iter().zip(xs.iter()) { for (y, x) in ys.iter().zip(xs.iter()) {
let q4 = &x.qs; let q4 = &x.qs;
@ -180,9 +231,8 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
let dmin = x.dmin.to_f32() * y.d; let dmin = x.dmin.to_f32() * y.d;
let dmin = f32x4_splat(dmin); let dmin = f32x4_splat(dmin);
let sumi = f32x4_convert_i32x4(sumi); let sumi = f32x4_convert_i32x4(sumi);
sumf = f32x4_add(sumf, f32x4_mul(sumi, dmin)); sums = f32x4_sub(sums, f32x4_mul(sumi, dmin));
} }
let sums = f32x4_sub(sums, sumf);
let sums = f32x4_extract_lane::<0>(sums) let sums = f32x4_extract_lane::<0>(sums)
+ f32x4_extract_lane::<1>(sums) + f32x4_extract_lane::<1>(sums)
+ f32x4_extract_lane::<2>(sums) + f32x4_extract_lane::<2>(sums)