mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
4 Commits
Author | SHA1 | Date | |
---|---|---|---|
c2261d0222 | |||
06d186355b | |||
2bbd544832 | |||
504d0b9ac7 |
@ -558,6 +558,26 @@ fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Res
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_small_tensors(
|
||||
m: usize,
|
||||
k: usize,
|
||||
n: usize,
|
||||
device: &Device,
|
||||
) -> Result<(Tensor, Tensor, Tensor)> {
|
||||
let lhs = (0..m * k)
|
||||
.map(|i| i as f32 / (m * k) as f32)
|
||||
.collect::<Vec<_>>();
|
||||
let rhs = (0..n * k)
|
||||
.map(|i| i as f32 / (n * k) as f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let lhs = Tensor::from_vec(lhs, (m, k), device)?;
|
||||
let rhs = Tensor::from_vec(rhs, (n, k), device)?;
|
||||
|
||||
let mm = lhs.matmul(&rhs.t()?)?;
|
||||
Ok((lhs, rhs, mm))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_mm() -> Result<()> {
|
||||
ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
|
||||
@ -623,20 +643,30 @@ fn quantized_matmul_q3k() -> Result<()> {
|
||||
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (11, 512, 21);
|
||||
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
let (lhs, rhs, mm) = get_small_tensors(m, k, n, cpu)?;
|
||||
// assert_eq!(mm.dims(), [m, n]);
|
||||
// let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
// assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
let qmm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);
|
||||
let error: f32 = ((&mm - &qmm)?.abs()? / &mm.abs()?)?
|
||||
.sum_all()?
|
||||
.to_scalar()?;
|
||||
let error = error / (m * n) as f32;
|
||||
|
||||
// assert_eq!(qmm.dims(), [m, n]);
|
||||
// let dst = qmm.flatten_all()?.to_vec1::<f32>()?;
|
||||
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
// assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);
|
||||
|
||||
assert!(
|
||||
error < 0.01,
|
||||
"{error} is too big, shouldn't exceed a few percent. \nGot:{qmm}\nExpected:\n{mm} "
|
||||
);
|
||||
|
||||
ggml_matmul_error_test::<BlockQ3K>()?;
|
||||
|
||||
@ -649,20 +679,30 @@ fn quantized_matmul_q4k() -> Result<()> {
|
||||
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (11, 512, 21);
|
||||
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
let (lhs, rhs, mm) = get_small_tensors(m, k, n, cpu)?;
|
||||
// assert_eq!(mm.dims(), [m, n]);
|
||||
// let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
// assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
let qmm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]);
|
||||
let error: f32 = ((&mm - &qmm)?.abs()? / &mm.abs()?)?
|
||||
.sum_all()?
|
||||
.to_scalar()?;
|
||||
let error = error / (m * n) as f32;
|
||||
|
||||
assert!(
|
||||
error < 0.01,
|
||||
"{error} is too big, shouldn't exceed a few percent. \nGot:{qmm}\nExpected:\n{mm} "
|
||||
);
|
||||
|
||||
// assert_eq!(mm.dims(), [m, n]);
|
||||
// let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
// assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]);
|
||||
|
||||
ggml_matmul_error_test::<BlockQ4K>()?;
|
||||
|
||||
|
Reference in New Issue
Block a user