diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index 37a4e3ba..d4b05bb0 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -638,3 +638,51 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
         Ok(hsum_float_8(acc) + summs)
     }
 }
+
+/// Dot product of two `BlockQ8K` slices using 256-bit SIMD.
+///
+/// For each pair of blocks the 8-bit quants are sign-extended to 16 bits,
+/// multiplied, and pairwise-summed into eight 32-bit lanes; that integer sum
+/// is then scaled by `x.d * y.d` and accumulated in `f32`.
+///
+/// # Errors
+///
+/// Returns an error if `n` is not a multiple of `QK_K`.
+#[inline(always)]
+pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
+    let qk = QK_K;
+    if n % qk != 0 {
+        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
+    }
+
+    // SAFETY: each 32-byte unaligned load starts at an offset `j < QK_K` inside a block's
+    // `qs` array. NOTE(review): this assumes `qs` holds `QK_K` bytes with `QK_K` a multiple
+    // of 32, and that the AVX2/FMA intrinsics used below are enabled wherever the caller's
+    // `cfg` gate lets this function run -- confirm both.
+    unsafe {
+        let mut acc = _mm256_setzero_ps();
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let mut sumi = _mm256_setzero_si256();
+            let x_qs = x.qs.as_ptr();
+            let y_qs = y.qs.as_ptr();
+            for j in (0..QK_K).step_by(32) {
+                let xv = _mm256_loadu_si256(x_qs.add(j) as *const __m256i);
+                let yv = _mm256_loadu_si256(y_qs.add(j) as *const __m256i);
+
+                // Low 16 bytes: widen i8 -> i16, multiply, add adjacent pairs into i32 lanes.
+                let x_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xv, 0));
+                let y_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(yv, 0));
+                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(x_lo, y_lo));
+
+                // High 16 bytes: same treatment.
+                let x_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xv, 1));
+                let y_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(yv, 1));
+                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(x_hi, y_hi));
+            }
+            // Scale this block's integer sum by the product of the two block scales.
+            let d = _mm256_set1_ps(x.d * y.d);
+            acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc);
+        }
+        Ok(hsum_float_8(acc))
+    }
+}
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 80d36555..7567c446 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1758,6 +1758,9 @@ impl GgmlType for BlockQ8K {
 
     #[allow(unreachable_code)]
     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        #[cfg(target_feature = "avx")]
+        return super::avx::vec_dot_q8k_q8k(n, xs, ys);
+
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q8k_q8k(n, xs, ys);
 
diff --git 
a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs
index 8beeab60..5d53728c 100644
--- a/candle-wasm-tests/tests/quantized_tests.rs
+++ b/candle-wasm-tests/tests/quantized_tests.rs
@@ -82,6 +82,9 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result {
         GgmlDType::Q5_0 => 0.001353,
         GgmlDType::Q5_1 => 0.001363,
         GgmlDType::Q8_0 => 0.000092,
+
+        // Not from the ggml repo.
+        GgmlDType::Q8K => 0.00065,
         _ => candle::bail!("No GGML results for quantization type {dtype:?}",),
     };
     Ok(err)
@@ -182,2 +185,11 @@ fn quantized_matmul_q6k() -> Result<()> {
     Ok(())
 }
+
+/// Runs `ggml_matmul_error_test` for the Q8K block type.
+#[wasm_bindgen_test]
+fn quantized_matmul_q8k() -> Result<()> {
+    // NOTE(review): the type argument is spelled with its full path so it does not rely on
+    // this file's imports -- confirm it matches the style of the neighbouring tests.
+    ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ8K>()?;
+    Ok(())
+}