Referenze implementations of q2k and q3k vec-dot functions (#580)

* add `q2k` vec-dot

* `q3k` vec-dot + quantization bugfix
This commit is contained in:
Lukas Kreussel
2023-08-24 13:35:54 +02:00
committed by GitHub
parent ca318a6ec7
commit d2f42ab086
2 changed files with 233 additions and 7 deletions

View File

@ -444,6 +444,60 @@ fn get_random_tensors(
Ok((lhs, rhs, mm))
}
#[test]
fn quantized_matmul_q2k() -> Result<()> {
use k_quants::BlockQ2K;
let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [0.916, 0.422, 0.215, 1.668]);
//mirrored GGML unit test
ggml_matmul_error_test::<BlockQ2K>()?;
Ok(())
}
#[test]
fn quantized_matmul_q3k() -> Result<()> {
use k_quants::BlockQ3K;
let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);
//mirrored GGML unit test
ggml_matmul_error_test::<BlockQ3K>()?;
Ok(())
}
#[test]
fn quantized_matmul_q6k() -> Result<()> {
use k_quants::BlockQ6K;