Simpler repro for the neon optimization issue + bugfix (#1544)

* Simpler repro for the neon optimization issue.

* Bugfix for q4k.

* Improve the fix, share the dot-prod bit.

* Clippy fixes.

* Fix for q6k.

* Also fix for q2k.

* Use the new shared dotprod.

* Add more testing.
This commit is contained in:
Laurent Mazare
2024-01-07 20:21:49 +01:00
committed by GitHub
parent 89b5a06858
commit 0eb90ed783
2 changed files with 97 additions and 168 deletions

View File

@ -12,6 +12,14 @@ use core::arch::arm::*;
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
#[inline(always)]
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
// TODO: dotprod
let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
}
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
// TODO: Support dotprod when it's available outside of nightly.
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
let pl0 = vdotq_s32(v0_0ls, v1_0l);
let ph0 = vdotq_s32(v0_0hs, v1_0h);
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
// TODO dotprod once this is the intrinsics are.
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
let p0 = vdotq_s32(x0_0, y0_0);
let p1 = vdotq_s32(x0_1, y0_1);
sumv0 = vmlaq_n_f32(
sumv0,
@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
for i in (0..QK_K).step_by(16) {
let xs = vld1q_s8(xs.add(i));
let ys = vld1q_s8(ys.add(i));
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
let xy = vdotq_s32(xs, ys);
sum_i = vaddq_s32(sum_i, xy)
}
sumf += vaddvq_s32(sum_i) as f32 * scale
@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
// TODO: dotprod
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
);
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
scale = scale.add(2);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
);
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
scale = scale.add(2);
let q8bytes = vld1q_s8_x4(q8);
@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
// TODO: dotprod case.
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
);
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
scale = scale.add(2);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
);
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
scale = scale.add(2);
}
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
// TODO: dotprod
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
);
sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
scales = scales.add(1);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
);
sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
scales = scales.add(1);
}
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
for j in 0..QK_K / 64 {
let q4bits = vld1q_u8_x2(q4);
q4 = q4.add(32);
// TODO: dotprod
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let q4bytes = int8x16x2_t(
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
);
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
);
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
);
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
}
sumf += d * (sumi1 + sumi2) as f32;
}
@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
vreinterpretq_s8_u8(q3h_3),
);
// TODO: dotprod
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
);
isum += vaddvq_s16(p0) as i32 * *scale as i32
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
isum += vaddvq_s32(p0) * *scale as i32
+ vaddvq_s32(p1) * *scale.add(1) as i32
+ vaddvq_s32(p2) * *scale.add(2) as i32
+ vaddvq_s32(p3) * *scale.add(3) as i32;
scale = scale.add(4);
let q3h_0 = vbicq_u8(m2, qhbits.0);
@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
vreinterpretq_s8_u8(q3h_3),
);
// TODO: dotprod
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
);
isum += vaddvq_s16(p0) as i32 * *scale as i32
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
isum += vaddvq_s32(p0) * *scale as i32
+ vaddvq_s32(p1) * *scale.add(1) as i32
+ vaddvq_s32(p2) * *scale.add(2) as i32
+ vaddvq_s32(p3) * *scale.add(3) as i32;
scale = scale.add(4);
if j == 0 {
@ -649,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
let mut is = 0usize;
// TODO: dotprod
for _j in 0..QK_K / 128 {
let q2bits = vld1q_u8_x2(q2);
q2 = q2.add(32);
@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
q2bytes: int8x16x2_t,
q8bytes: int8x16x2_t,
) -> i32 {
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
);
vaddvq_s16(p1) as i32 * aux[is + index] as i32
+ vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
}

View File

@ -1,4 +1,5 @@
use candle_core::{
bail,
quantized::{self, GgmlDType},
test_utils::to_vec2_round,
Device, Module, Result, Tensor,
@ -265,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
}
}
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
/// Creates a vector similar to the ones used in GGML unit tests:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
(0..GGML_TEST_SIZE)
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
@ -284,14 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
sum / a.len() as f32
}
/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
/// Similar to the GGML quantization unit test:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
let src = create_ggml_like_vector(0.0);
let mut dst = vec![0.0; GGML_TEST_SIZE];
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
let error = calculate_rmse(src.as_slice(), dst.as_slice());
if error > max_error {
candle_core::bail!(
bail!(
"Quantization error {} exceeds max error {}",
error,
max_error
@ -487,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
GgmlDType::Q5K => 0.000740,
GgmlDType::Q6K => 0.000952,
GgmlDType::Q4_0 => 0.001143,
GgmlDType::Q4_1 => 0.007784,
GgmlDType::Q4_1 => 0.008,
GgmlDType::Q5_0 => 0.001353,
GgmlDType::Q5_1 => 0.001363,
GgmlDType::Q5_1 => 0.00149,
GgmlDType::Q8_0 => 0.000092,
// Not from the ggml repo.
GgmlDType::Q8K => 0.00065,
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
_ => bail!("No GGML results for quantization type {dtype:?}",),
};
Ok(err)
}
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
/// Similar to the GGML matmul unit test:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
let a = create_ggml_like_vector(0.0);
let b = create_ggml_like_vector(1.0);
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
// Another example that is more likely to trigger the overflow reported in #1526
let a = (0..GGML_TEST_SIZE)
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
.collect::<Vec<_>>();
let b = (0..GGML_TEST_SIZE)
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
.collect::<Vec<_>>();
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
Ok(())
}
fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
let length = a.len();
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
T::from_float(&a, &mut a_quant)?;
T::VecDotType::from_float(&b, &mut b_quant)?;
T::from_float(a, &mut a_quant)?;
T::VecDotType::from_float(b, &mut b_quant)?;
let result = T::vec_dot(length, &a_quant, &b_quant)?;
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
let reference_result = vec_dot_reference(&a, &b);
let reference_result = vec_dot_reference(a, b);
if (result - result_unopt).abs() / length as f32 > 1e-6 {
candle_core::bail!(
bail!(
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
)
}
let error = (result - reference_result).abs() / length as f32;
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
candle_core::bail!(
"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
);
bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
}
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
// => we use a slightly higher error threshold
const ERROR_LENIENCY: f32 = 0.00001;
if error - ERROR_LENIENCY > ggml_error {
candle_core::bail!(
bail!(
"Dot product error {} exceeds ggml reference error {}",
error,
ggml_error
@ -543,6 +558,16 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
Ok(())
}
#[test]
fn quantized_mm() -> Result<()> {
ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
Ok(())
}
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
fn get_random_tensors(
m: usize,