diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index ac3f7def..80d36555 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1756,14 +1756,20 @@ impl GgmlType for BlockQ8K {
     const BLCK_SIZE: usize = QK_K;
     type VecDotType = BlockQ8K;
 
+    // With NEON enabled the early `return` below always fires, which makes the
+    // fallback call unreachable; without NEON the `return` is compiled out.
+    #[allow(unreachable_code)]
     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        #[cfg(target_feature = "neon")]
+        return super::neon::vec_dot_q8k_q8k(n, xs, ys);
+
         Self::vec_dot_unopt(n, xs, ys)
     }
 
     fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
-        let qk = QK8_0;
-        if n % QK8_0 != 0 {
-            crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
+        let qk = QK_K;
+        if n % QK_K != 0 {
+            crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
         }
 
         // Generic implementation.
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index 7f76dadc..fd4c1388 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -148,6 +148,50 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
     }
 }
 
+/// Dot product of two `BlockQ8K` slices using NEON intrinsics.
+///
+/// Blocks are paired with `zip`. For each pair, the element-wise products of the
+/// two `qs` buffers are summed as integers, scaled by `xs.d * ys.d`, and added to
+/// the `f32` total that is returned.
+///
+/// # Errors
+///
+/// Fails when `n` is not a multiple of `QK_K`.
+#[inline(always)]
+pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
+    let qk = QK_K;
+    if n % QK_K != 0 {
+        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
+    }
+
+    let mut sumf = 0f32;
+    for (xs, ys) in xs.iter().zip(ys.iter()) {
+        // SAFETY: every load reads 16 `i8` lanes starting at offset `i`, and `i` steps
+        // by 16 below `QK_K`. This assumes each `qs` buffer holds at least `QK_K`
+        // elements and that `QK_K` is a multiple of 16 (TODO confirm against the
+        // `BlockQ8K` definition).
+        unsafe {
+            let mut sum_i = vdupq_n_s32(0);
+            let scale = xs.d * ys.d;
+            let xs = xs.qs.as_ptr();
+            let ys = ys.qs.as_ptr();
+            for i in (0..QK_K).step_by(16) {
+                let xs = vld1q_s8(xs.add(i));
+                let ys = vld1q_s8(ys.add(i));
+                // Widening multiply: each i8 * i8 product fits in an i16 lane.
+                let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
+                let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
+
+                // Pairwise-add adjacent i16 lanes into i32 lanes before accumulating.
+                let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
+                sum_i = vaddq_s32(sum_i, xy)
+            }
+            sumf += vaddvq_s32(sum_i) as f32 * scale
+        }
+    }
+    Ok(sumf)
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
     if n % QK_K != 0 {