From fa0d75b18d2eaa8662be52b88991bb5c87472a93 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 27 Aug 2023 08:17:38 +0100 Subject: [PATCH] Quantization tests + fix some issues. (#616) --- candle-core/src/quantized/k_quants.rs | 12 ++-- candle-core/tests/quantized_tests.rs | 93 +++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 6 deletions(-) diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index f94ff72f..177047b6 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -302,9 +302,9 @@ impl GgmlType for BlockQ4_1 { ys.d = f16::from_f32(d); ys.m = f16::from_f32(min); - for (j, q) in ys.qs.iter_mut().enumerate() { - let x0 = (xs[i * qk + j] - min) * id; - let x1 = (xs[i * qk + qk / 2 + j] - min) * id; + for (j, q) in ys.qs.iter_mut().take(qk / 2).enumerate() { + let x0 = (xs[j] - min) * id; + let x1 = (xs[qk / 2 + j] - min) * id; let xi0 = u8::min(15, (x0 + 0.5) as u8); let xi1 = u8::min(15, (x1 + 0.5) as u8); @@ -496,9 +496,9 @@ impl GgmlType for BlockQ5_1 { ys.m = f16::from_f32(min); let mut qh = 0u32; - for (j, q) in ys.qs.iter_mut().enumerate() { - let x0 = (xs[i * qk + j] - min) * id; - let x1 = (xs[i * qk + qk / 2 + j] - min) * id; + for (j, q) in ys.qs.iter_mut().take(qk / 2).enumerate() { + let x0 = (xs[j] - min) * id; + let x1 = (xs[qk / 2 + j] - min) * id; let xi0 = (x0 + 0.5) as u8; let xi1 = (x1 + 0.5) as u8; diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index a679e7b5..4f143492 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -135,7 +135,100 @@ fn quantize_q4_0() -> Result<()> { //mirrored GGML unit test ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + Ok(()) +} +#[test] +fn quantize_q4_1() -> Result<()> { + use k_quants::BlockQ4_1; + + let src = (0..32 * 4).map(|v| v as f32).collect::>(); + let mut dst = vec![0f32; 32 * 4]; + let mut quant = vec![BlockQ4_1::zeros(); 4]; + BlockQ4_1::from_float(&src, &mut quant)?; + BlockQ4_1::to_float(&quant, dst.as_mut_slice())?; + assert_eq!( + round_vector(&dst), + &[ + 0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332, + 12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73, + 22.73, 24.797, 24.797, 26.863, 26.863, 28.93, 28.93, 30.996, 30.996, 32.0, 32.0, + 34.066, 34.066, 36.133, 36.133, 38.199, 38.199, 40.266, 40.266, 42.332, 42.332, 44.398, + 44.398, 46.465, 46.465, 48.531, 48.531, 50.598, 50.598, 52.664, 52.664, 54.73, 54.73, + 56.797, 56.797, 58.863, 58.863, 60.93, 60.93, 62.996, 62.996, 64.0, 64.0, 66.066, + 66.066, 68.133, 68.133, 70.199, 70.199, 72.266, 72.266, 74.332, 74.332, 76.398, 76.398, + 78.465, 78.465, 80.531, 80.531, 82.598, 82.598, 84.664, 84.664, 86.73, 86.73, 88.797, + 88.797, 90.863, 90.863, 92.93, 92.93, 94.996, 94.996, 96.0, 96.0, 98.066, 98.066, + 100.133, 100.133, 102.199, 102.199, 104.266, 104.266, 106.332, 106.332, 108.398, + 108.398, 110.465, 110.465, 112.531, 112.531, 114.598, 114.598, 116.664, 116.664, + 118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996 + ] + ); + + //mirrored GGML unit test + ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + Ok(()) +} + +#[test] +fn quantize_q5_0() -> Result<()> { + use k_quants::BlockQ5_0; + + let src = (0..32 * 4).map(|v| v as f32).collect::>(); + let mut dst = vec![0f32; 32 * 4]; + let mut quant = vec![BlockQ5_0::zeros(); 4]; + BlockQ5_0::from_float(&src, &mut quant)?; + BlockQ5_0::to_float(&quant, dst.as_mut_slice())?; + assert_eq!( + round_vector(&dst), + &[ + -0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625, + 11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313, + 23.25, 23.25, 25.188, 25.188, 27.125, 27.125, 29.063, 29.063, 31.0, 31.5, 31.5, 35.438, + 35.438, 35.438, 35.438, 39.375, 39.375, 39.375, 39.375, 43.313, 43.313, 43.313, 43.313, + 47.25, 47.25, 47.25, 47.25, 51.188, 51.188, 51.188, 51.188, 55.125, 55.125, 55.125, + 55.125, 59.063, 59.063, 59.063, 59.063, 63.0, 63.0, 65.313, 65.313, 65.313, 65.313, + 65.313, 71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 77.188, 77.188, 77.188, 77.188, + 77.188, 77.188, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 89.063, 89.063, 89.063, + 89.063, 89.063, 89.063, 95.0, 95.0, 95.0, 95.25, 95.25, 95.25, 95.25, 103.188, 103.188, + 103.188, 103.188, 103.188, 103.188, 103.188, 103.188, 111.125, 111.125, 111.125, + 111.125, 111.125, 111.125, 111.125, 111.125, 119.063, 119.063, 119.063, 119.063, + 119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0 + ] + ); + + //mirrored GGML unit test + ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + Ok(()) +} + +#[test] +fn quantize_q5_1() -> Result<()> { + use k_quants::BlockQ5_1; + + let src = (0..32 * 4).map(|v| v as f32).collect::>(); + let mut dst = vec![0f32; 32 * 4]; + let mut quant = vec![BlockQ5_1::zeros(); 4]; + BlockQ5_1::from_float(&src, &mut quant)?; + BlockQ5_1::to_float(&quant, dst.as_mut_slice())?; + assert_eq!( + dst, + &[ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, + 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, + 16.0, 16.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, + 44.0, 45.0, 46.0, 47.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, + 48.0, 48.0, 48.0, 48.0, 48.0, 48.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, + 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 80.0, 80.0, 80.0, 80.0, 80.0, + 80.0, 80.0, 80.0, 80.0, 80.0, 80.0, 80.0, 80.0, 80.0, 80.0, 96.0, 97.0, 98.0, 99.0, + 100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, + 112.0, 112.0, 112.0, 112.0, 112.0, 112.0, 112.0, 112.0, 112.0, 112.0, 112.0, 112.0, + 112.0, 112.0, 112.0, 112.0 + ] + ); + + //mirrored GGML unit test + ggml_quantization_error_test::(0.014)?; Ok(()) }