diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 5b5ea4b0..9a72d88e 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -687,6 +687,9 @@ impl GgmlType for BlockQ2K { #[cfg(target_feature = "neon")] return super::neon::vec_dot_q2k_q8k(n, xs, ys); + #[cfg(target_feature = "simd128")] + return super::simd128::vec_dot_q2k_q8k(n, xs, ys); + if n % QK_K != 0 { crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") } diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs index 6b225cce..061421c4 100644 --- a/candle-core/src/quantized/simd128.rs +++ b/candle-core/src/quantized/simd128.rs @@ -1,4 +1,4 @@ -use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ8K, BlockQ8_0, QK8_0, QK_K}; +use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ8K, BlockQ8_0, QK8_0, QK_K}; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; use half::f16; @@ -99,6 +99,58 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> } } +#[inline(always)] +pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}") + } + let mut sumf = 0.0; + for (x, y) in xs.iter().zip(ys.iter()) { + let mut q2: &[_] = &x.qs; + let mut q8: &[_] = &y.qs; + let sc = &x.scales; + + let mut summs = 0; + for (bsum, scale) in y.bsums.iter().zip(sc) { + summs += *bsum as i32 * ((scale >> 4) as i32); + } + + let dall = y.d * x.d.to_f32(); + let dmin = y.d * x.dmin.to_f32(); + + let mut isum = 0; + let mut is = 0; + let mut d; + for _ in 0..(QK_K / 128) { + let mut shift = 0; + for _ in 0..4 { + d = (sc[is] & 0xF) as i32; + is += 1; + let mut isuml = 0; + for l in 0..16 { + isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32); + } + isum += d * isuml; + d = (sc[is] & 0xF) as i32; + is += 1; + isuml = 0; + for l in 16..32 { + isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32); + } + isum += d * isuml; + shift += 2; + // adjust the indexing + q8 = &q8[32..]; + } + // adjust the indexing + q2 = &q2[32..]; + } + sumf += dall * isum as f32 - dmin * summs as f32; + } + + Ok(sumf) +} + #[inline(always)] pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { @@ -115,7 +167,6 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res let mut aux8: [u8; QK_K] = [0; QK_K]; let mut sums = f32x4_splat(0f32); - let mut sumf = f32x4_splat(0f32); unsafe { for (y, x) in ys.iter().zip(xs.iter()) { let q4 = &x.qs; @@ -180,9 +231,8 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res let dmin = x.dmin.to_f32() * y.d; let dmin = f32x4_splat(dmin); let sumi = f32x4_convert_i32x4(sumi); - sumf = f32x4_add(sumf, f32x4_mul(sumi, dmin)); + sums = f32x4_sub(sums, f32x4_mul(sumi, dmin)); } - let sums = f32x4_sub(sums, sumf); let sums = f32x4_extract_lane::<0>(sums) + f32x4_extract_lane::<1>(sums) + f32x4_extract_lane::<2>(sums)