// candle/candle-wasm-tests/tests/quantized_tests.rs

use candle::{
quantized::{self, k_quants, GgmlDType, GgmlType},
test_utils::to_vec2_round,
Device, Result, Tensor,
};
use wasm_bindgen_test::*;
wasm_bindgen_test_configure!(run_in_browser);
/// Checks a Q4_0 quantized matmul on inputs that contain negative values.
///
/// The same `(m, k) x (k, n)` product is computed three ways:
/// 1. the low-level `k_quants::matmul` kernel on a Q4_0-quantized rhs,
/// 2. an unquantized f32 `Tensor::matmul` (reference; differs slightly because
///    of quantization loss),
/// 3. the high-level `QMatMul` wrapper over the same quantized rhs.
///
/// Paths 1 and 3 are expected to produce identical (rounded) results.
#[wasm_bindgen_test]
fn quantized_matmul_neg() -> Result<()> {
    let cpu = &Device::Cpu;
    let (m, k, n) = (3, 64, 4);
    // lhs is (m, k), shifted so that roughly half of the values are negative.
    let lhs = (0..(m * k))
        .map(|v| v as f32 - (m * k) as f32 / 2.0)
        .collect::<Vec<_>>();
    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
    // (m, n) output, pre-filled with a non-zero value; the exact-value assertion
    // below would catch any slot the kernel failed to write.
    let mut dst = vec![42.; m * n];
    // The quantized rhs holds n rows of k values; each row packs into
    // k / BLCK_SIZE Q4_0 blocks. Derived from (m, k, n) so resizing stays consistent.
    let n_blocks = n * k / k_quants::BlockQ4_0::BLCK_SIZE;
    let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); n_blocks];
    let rhs = (0..k * n)
        .map(|v| v as f32 - (k * n) as f32 / 3.0)
        .collect::<Vec<_>>();
    // rhs data is laid out as (n, k); transpose to (k, n) for the f32 reference.
    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
    k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
    k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
    assert_eq!(
        dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
        &[
            243524.0, -19596.0, -285051.0, -549815.0, 23777.0, 21651.0, 19398.0, 18367.0,
            -196472.0, 63012.0, 324585.0, 587902.0
        ]
    );
    // Unquantized reference product.
    let mm = tensor_lhs.matmul(&tensor_rhs)?;
    assert_eq!(
        to_vec2_round(&mm, 0)?,
        &[
            [244064.0, -20128.0, -284320.0, -548512.0],
            [23563.0, 21515.0, 19467.0, 17419.0],
            [-196939.0, 63157.0, 323253.0, 583349.0]
        ]
    );
    // High-level path over the same quantized blocks; shape is (n, k) like `rhs`.
    let qtensor = quantized::QTensor::new(rhs_t, (n, k))?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor);
    let res = matmul.forward(&tensor_lhs)?;
    assert_eq!(
        to_vec2_round(&res, 0)?,
        &[
            [243524.0, -19596.0, -285051.0, -549815.0],
            [23777.0, 21651.0, 19398.0, 18367.0],
            [-196472.0, 63012.0, 324585.0, 587902.0]
        ]
    );
    Ok(())
}
/// Builds a deterministic test vector in the same way as GGML's quantization
/// unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
///
/// The vector has 32 * 128 elements, and element `i` equals `0.1 + 2.0 * cos(i + offset)`.
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
    const GGML_TEST_SIZE: usize = 32 * 128;
    let mut values = Vec::with_capacity(GGML_TEST_SIZE);
    for idx in 0..GGML_TEST_SIZE {
        let phase = idx as f32 + offset;
        values.push(0.1 + 2.0 * phase.cos());
    }
    values
}
/// Naive f32 dot product, used as the exact reference for the quantized kernels.
///
/// Elements are paired with `zip`, so any trailing elements of the longer slice are ignored.
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
/// Looks up the per-element dot-product error that GGML's own matmul unit test
/// reports for `dtype`. `ggml_matmul_error_test` uses it as the baseline to match.
///
/// # Errors
/// Bails for any quantization type that has no recorded GGML result.
fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
    Ok(match dtype {
        GgmlDType::F16 => 0.000010,
        GgmlDType::Q2K => 0.004086,
        GgmlDType::Q3K => 0.016148,
        GgmlDType::Q4K => 0.002425,
        GgmlDType::Q5K => 0.000740,
        GgmlDType::Q6K => 0.000952,
        GgmlDType::Q4_0 => 0.001143,
        GgmlDType::Q4_1 => 0.007784,
        GgmlDType::Q5_0 => 0.001353,
        GgmlDType::Q5_1 => 0.001363,
        GgmlDType::Q8_0 => 0.000092,
        // GgmlDType is defined in candle; other variants have no GGML baseline here.
        _ => candle::bail!("No GGML results for quantization type {dtype:?}"),
    })
}
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
let a = create_ggml_like_vector(0.0);
let b = create_ggml_like_vector(1.0);
let length = a.len();
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
T::from_float(&a, &mut a_quant)?;
T::VecDotType::from_float(&b, &mut b_quant)?;
let result = T::vec_dot(length, &a_quant, &b_quant)?;
let reference_result = vec_dot_reference(&a, &b);
let error = (result - reference_result).abs() / length as f32;
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
if error > GGML_MAX_DOT_PRODUCT_ERROR {
candle::bail!(
"Dot product error {} exceeds max error {}",
error,
GGML_MAX_DOT_PRODUCT_ERROR
);
}
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
// => we use a slightly higher error threshold
const ERROR_LENIENCY: f32 = 0.00001;
if error - ERROR_LENIENCY > ggml_error {
candle::bail!(
"Dot product error {} exceeds ggml reference error {}",
error,
ggml_error
);
}
Ok(())
}
#[wasm_bindgen_test]
fn quantized_matmul_q40() -> Result<()> {
ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ4_0>()?;
Ok(())
}
#[wasm_bindgen_test]
fn quantized_matmul_q50() -> Result<()> {
ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ5_0>()?;
Ok(())
}
#[wasm_bindgen_test]
fn quantized_matmul_q80() -> Result<()> {
ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ8_0>()?;
Ok(())
}
#[wasm_bindgen_test]
fn quantized_matmul_q2k() -> Result<()> {
ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ2K>()?;
Ok(())
}
#[wasm_bindgen_test]
fn quantized_matmul_q3k() -> Result<()> {
ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ3K>()?;
Ok(())
}
#[wasm_bindgen_test]
fn quantized_matmul_q4k() -> Result<()> {
ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ4K>()?;
Ok(())
}
#[wasm_bindgen_test]
fn quantized_matmul_q5k() -> Result<()> {
ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ5K>()?;
Ok(())
}
#[wasm_bindgen_test]
fn quantized_matmul_q6k() -> Result<()> {
ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ6K>()?;
Ok(())
}