From 9c8d6dbc2a0f01c5e5d33fcce172a6ab2f9a0ec7 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 25 Aug 2023 14:42:18 +0100 Subject: [PATCH] Neon intrinsics for the q8_0 vecdot. (#604) * Neon intrinsics for the q8_0 vecdot. * Get the tests to run with accelerate (with some numerical error failures). --- candle-core/src/lib.rs | 3 ++ candle-core/src/quantized/k_quants.rs | 3 ++ candle-core/src/quantized/neon.rs | 61 +++++++++++++++++++++++++++ 3 files changed, 67 insertions(+) diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 3622d22e..ddd446ee 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -87,3 +87,6 @@ pub use dummy_cuda_backend::{CudaDevice, CudaStorage}; #[cfg(feature = "mkl")] extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 02022480..3aefa5df 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -426,6 +426,9 @@ impl GgmlType for BlockQ8_0 { #[cfg(target_feature = "avx")] return super::avx::vec_dot_q8_0_q8_0(n, xs, ys); + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q8_0_q8_0(n, xs, ys); + let qk = QK8_0; if n % QK8_0 != 0 { crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index a4d70350..32c93af4 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -84,6 +84,67 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> } } +#[inline(always)] +pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") + } + let nb = n / QK8_0; + if nb % 2 != 0 { + crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even") + } + unsafe { + let mut 
sumv0 = vdupq_n_f32(0.0f32); + let mut sumv1 = vdupq_n_f32(0.0f32); + for i in (0..nb).step_by(2) { + let x0 = &xs[i]; + let x1 = &xs[i + 1]; + let y0 = &ys[i]; + let y1 = &ys[i + 1]; + + let x0_0 = vld1q_s8(x0.qs.as_ptr()); + let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16)); + let x1_0 = vld1q_s8(x1.qs.as_ptr()); + let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16)); + + // load y + let y0_0 = vld1q_s8(y0.qs.as_ptr()); + let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); + let y1_0 = vld1q_s8(y1.qs.as_ptr()); + let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16)); + + // TODO: use dotprod once the intrinsics are available. + let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0)); + let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); + let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1)); + let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); + + let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0)); + let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); + let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1)); + let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); + + let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); + let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); + let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); + let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); + + sumv0 = vmlaq_n_f32( + sumv0, + vcvtq_f32_s32(vaddq_s32(p0, p1)), + x0.d.to_f32() * y0.d.to_f32(), + ); + sumv1 = vmlaq_n_f32( + sumv1, + vcvtq_f32_s32(vaddq_s32(p2, p3)), + x1.d.to_f32() * y1.d.to_f32(), + ); + } + Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1)) + } +} + #[inline(always)] pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 {