mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Simpler repro for the neon optimization issue + bugfix (#1544)
* Simpler repro for the neon optimization issue. * Bugfix for q4k. * Improve the fix, share the dot-prod bit. * Clippy fixes. * Fix for q6k. * Also fix for q2k. * Use the new shared dotprod. * Add more testing.
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
use candle_core::{
|
||||
bail,
|
||||
quantized::{self, GgmlDType},
|
||||
test_utils::to_vec2_round,
|
||||
Device, Module, Result, Tensor,
|
||||
@ -265,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
/// Creates a vector similar to the ones used in GGML unit tests:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
|
||||
(0..GGML_TEST_SIZE)
|
||||
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
|
||||
@ -284,14 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
|
||||
sum / a.len() as f32
|
||||
}
|
||||
|
||||
/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||
/// Similar to the GGML quantization unit test:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
|
||||
let src = create_ggml_like_vector(0.0);
|
||||
let mut dst = vec![0.0; GGML_TEST_SIZE];
|
||||
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let error = calculate_rmse(src.as_slice(), dst.as_slice());
|
||||
if error > max_error {
|
||||
candle_core::bail!(
|
||||
bail!(
|
||||
"Quantization error {} exceeds max error {}",
|
||||
error,
|
||||
max_error
|
||||
@ -487,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
|
||||
GgmlDType::Q5K => 0.000740,
|
||||
GgmlDType::Q6K => 0.000952,
|
||||
GgmlDType::Q4_0 => 0.001143,
|
||||
GgmlDType::Q4_1 => 0.007784,
|
||||
GgmlDType::Q4_1 => 0.008,
|
||||
GgmlDType::Q5_0 => 0.001353,
|
||||
GgmlDType::Q5_1 => 0.001363,
|
||||
GgmlDType::Q5_1 => 0.00149,
|
||||
GgmlDType::Q8_0 => 0.000092,
|
||||
|
||||
// Not from the ggml repo.
|
||||
GgmlDType::Q8K => 0.00065,
|
||||
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
|
||||
_ => bail!("No GGML results for quantization type {dtype:?}",),
|
||||
};
|
||||
Ok(err)
|
||||
}
|
||||
|
||||
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
||||
/// Similar to the GGML matmul unit test:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
||||
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
||||
let a = create_ggml_like_vector(0.0);
|
||||
let b = create_ggml_like_vector(1.0);
|
||||
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
|
||||
// Another example that is more likely to trigger the overflow reported in #1526
|
||||
let a = (0..GGML_TEST_SIZE)
|
||||
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
|
||||
.collect::<Vec<_>>();
|
||||
let b = (0..GGML_TEST_SIZE)
|
||||
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
|
||||
.collect::<Vec<_>>();
|
||||
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
|
||||
let length = a.len();
|
||||
|
||||
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
|
||||
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
|
||||
T::from_float(&a, &mut a_quant)?;
|
||||
T::VecDotType::from_float(&b, &mut b_quant)?;
|
||||
T::from_float(a, &mut a_quant)?;
|
||||
T::VecDotType::from_float(b, &mut b_quant)?;
|
||||
|
||||
let result = T::vec_dot(length, &a_quant, &b_quant)?;
|
||||
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
|
||||
let reference_result = vec_dot_reference(&a, &b);
|
||||
let reference_result = vec_dot_reference(a, b);
|
||||
|
||||
if (result - result_unopt).abs() / length as f32 > 1e-6 {
|
||||
candle_core::bail!(
|
||||
bail!(
|
||||
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
|
||||
)
|
||||
}
|
||||
|
||||
let error = (result - reference_result).abs() / length as f32;
|
||||
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
|
||||
|
||||
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
|
||||
candle_core::bail!(
|
||||
"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
|
||||
);
|
||||
bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
|
||||
}
|
||||
|
||||
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
|
||||
// => we use a slightly higher error threshold
|
||||
const ERROR_LENIENCY: f32 = 0.00001;
|
||||
if error - ERROR_LENIENCY > ggml_error {
|
||||
candle_core::bail!(
|
||||
bail!(
|
||||
"Dot product error {} exceeds ggml reference error {}",
|
||||
error,
|
||||
ggml_error
|
||||
@ -543,6 +558,16 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_mm() -> Result<()> {
|
||||
ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
|
||||
fn get_random_tensors(
|
||||
m: usize,
|
||||
|
Reference in New Issue
Block a user