diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 099d0079..ac3f7def 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1760,8 +1760,24 @@ impl GgmlType for BlockQ8K { Self::vec_dot_unopt(n, xs, ys) } - fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result { - unreachable!() + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") + } + + // Generic implementation. + let mut sumf = 0f32; + for (xs, ys) in xs.iter().zip(ys.iter()) { + let sum_i = xs + .qs + .iter() + .zip(ys.qs.iter()) + .map(|(&x, &y)| x as i32 * y as i32) + .sum::(); + sumf += sum_i as f32 * xs.d * ys.d + } + Ok(sumf) } fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 06f1ee47..a2cecbc3 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -491,6 +491,9 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result { GgmlDType::Q5_0 => 0.001353, GgmlDType::Q5_1 => 0.001363, GgmlDType::Q8_0 => 0.000092, + + // Not from the ggml repo. + GgmlDType::Q8K => 0.00065, _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",), }; Ok(err) @@ -692,3 +695,28 @@ fn quantized_matmul_q6k() -> Result<()> { ggml_matmul_error_test::()?; Ok(()) } + +#[test] +fn quantized_matmul_q8k() -> Result<()> { + use k_quants::BlockQ8K; + + let cpu = &Device::Cpu; + let (m, k, n) = (11, 512, 21); + let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); + + let rhs = quantized::QTensor::quantize::(&rhs)?; + let rhs = quantized::QMatMul::from_qtensor(rhs)?; + let mm = rhs.forward(&lhs)?; + + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.266, 1.504, -0.204, 1.7]); + + ggml_matmul_error_test::()?; + Ok(()) +}