mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Referenze implementations of q2k
and q3k
vec-dot functions (#580)
* add `q2k` vec-dot * `q3k` vec-dot + quantization bugfix
This commit is contained in:
@ -444,6 +444,60 @@ fn get_random_tensors(
|
||||
Ok((lhs, rhs, mm))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q2k() -> Result<()> {
|
||||
use k_quants::BlockQ2K;
|
||||
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (11, 512, 21);
|
||||
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [0.916, 0.422, 0.215, 1.668]);
|
||||
|
||||
//mirrored GGML unit test
|
||||
ggml_matmul_error_test::<BlockQ2K>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q3k() -> Result<()> {
|
||||
use k_quants::BlockQ3K;
|
||||
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (11, 512, 21);
|
||||
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);
|
||||
|
||||
//mirrored GGML unit test
|
||||
ggml_matmul_error_test::<BlockQ3K>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q6k() -> Result<()> {
|
||||
use k_quants::BlockQ6K;
|
||||
|
Reference in New Issue
Block a user