From ef33df7ae2b94e2b911b61f3765d6826726614e7 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 28 Oct 2023 08:23:59 +0200 Subject: [PATCH] No need for the even constraint on vecdot-q40-q80. (#1202) --- candle-core/src/quantized/avx.rs | 5 ----- candle-core/src/quantized/k_quants.rs | 5 ----- candle-core/src/quantized/neon.rs | 29 ++------------------------- candle-core/src/quantized/simd128.rs | 4 ---- 4 files changed, 2 insertions(+), 41 deletions(-) diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs index d4b05bb0..5c3ac822 100644 --- a/candle-core/src/quantized/avx.rs +++ b/candle-core/src/quantized/avx.rs @@ -50,14 +50,9 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 { #[inline(always)] pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result { let qk = QK8_0; - let nb = n / qk; if n % QK8_0 != 0 { crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") } - if nb % 2 != 0 { - crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") - } - unsafe { let mut acc = _mm256_setzero_ps(); for (x, y) in xs.iter().zip(ys.iter()) { diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index b140131e..d16289e6 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -236,14 +236,9 @@ impl GgmlType for BlockQ4_0 { fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { let qk = QK8_0; - let nb = n / qk; if n % QK8_0 != 0 { crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") } - if nb % 2 != 0 { - crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") - } - // Generic implementation. let mut sumf = 0f32; for (xs, ys) in xs.iter().zip(ys.iter()) { diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index 51bd3e66..3cb56229 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -19,42 +19,29 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> if n % QK8_0 != 0 { crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") } - if nb % 2 != 0 { - crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") - } unsafe { let mut sumv0 = vdupq_n_f32(0.0f32); - let mut sumv1 = vdupq_n_f32(0.0f32); - for i in (0..nb).step_by(2) { + for i in 0..nb { let x0 = &xs[i]; - let x1 = &xs[i + 1]; let y0 = &ys[i]; - let y1 = &ys[i + 1]; let m4b = vdupq_n_u8(0x0F); let s8b = vdupq_n_s8(0x8); let v0_0 = vld1q_u8(x0.qs.as_ptr()); - let v0_1 = vld1q_u8(x1.qs.as_ptr()); // 4-bit -> 8-bit let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b)); - let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); // sub 8 let v0_0ls = vsubq_s8(v0_0l, s8b); let v0_0hs = vsubq_s8(v0_0h, s8b); - let v0_1ls = vsubq_s8(v0_1l, s8b); - let v0_1hs = vsubq_s8(v0_1h, s8b); // load y let v1_0l = vld1q_s8(y0.qs.as_ptr()); let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16)); - let v1_1l = vld1q_s8(y1.qs.as_ptr()); - let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16)); // TODO: Support dotprod when it's available outside of nightly. let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l)); @@ -62,28 +49,16 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h)); let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); - let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l)); - let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); - let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h)); - let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); - let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); sumv0 = vmlaq_n_f32( sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0.d.to_f32() * y0.d.to_f32(), ); - sumv1 = vmlaq_n_f32( - sumv1, - vcvtq_f32_s32(vaddq_s32(pl1, ph1)), - x1.d.to_f32() * y1.d.to_f32(), - ); } - Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1)) + Ok(vaddvq_f32(sumv0)) } } diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs index f256fdc2..1c8c0f20 100644 --- a/candle-core/src/quantized/simd128.rs +++ b/candle-core/src/quantized/simd128.rs @@ -11,10 +11,6 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> if n % QK8_0 != 0 { crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") } - let nb = n / QK8_0; - if nb % 2 != 0 { - crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") - } unsafe { let mut acc = f32x4_splat(0.0f32); for (x, y) in xs.iter().zip(ys.iter()) {