Mirror of https://github.com/huggingface/candle.git
Add the q8k vec-dot multiplication. (#1019)
@@ -1760,8 +1760,24 @@ impl GgmlType for BlockQ8K {
         Self::vec_dot_unopt(n, xs, ys)
     }
 
-    fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
-        unreachable!()
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        let qk = QK8_0;
+        if n % QK8_0 != 0 {
+            crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
+        }
+
+        // Generic implementation.
+        let mut sumf = 0f32;
+        for (xs, ys) in xs.iter().zip(ys.iter()) {
+            let sum_i = xs
+                .qs
+                .iter()
+                .zip(ys.qs.iter())
+                .map(|(&x, &y)| x as i32 * y as i32)
+                .sum::<i32>();
+            sumf += sum_i as f32 * xs.d * ys.d
+        }
+        Ok(sumf)
     }
 
     fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
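Aside for readers skimming the hunk: the new vec_dot_unopt follows the standard quantized dot-product pattern, an integer multiply-accumulate inside each block followed by a single float rescale per pair of blocks. Below is a minimal, self-contained Rust sketch of that pattern; ToyBlock is a hypothetical stand-in for BlockQ8K (real Q8K blocks carry QK_K = 256 quants plus block sums), so this is an illustration, not candle code.

// ToyBlock is a hypothetical miniature of BlockQ8K: one per-block scale `d`
// and a handful of i8 quants (candle's real Q8K blocks hold 256 of them).
struct ToyBlock {
    d: f32,
    qs: [i8; 4],
}

// Same shape as the vec_dot_unopt above: an integer multiply-accumulate per
// block, then one float rescale by the product of the two block scales.
fn toy_vec_dot(xs: &[ToyBlock], ys: &[ToyBlock]) -> f32 {
    let mut sumf = 0f32;
    for (x, y) in xs.iter().zip(ys.iter()) {
        let sum_i: i32 = x
            .qs
            .iter()
            .zip(y.qs.iter())
            .map(|(&a, &b)| a as i32 * b as i32)
            .sum();
        sumf += sum_i as f32 * x.d * y.d;
    }
    sumf
}

fn main() {
    let x = ToyBlock { d: 0.5, qs: [1, -2, 3, 4] };
    let y = ToyBlock { d: 0.25, qs: [2, 2, -1, 1] };
    // sum_i = 1*2 + (-2)*2 + 3*(-1) + 4*1 = -1, so the result is -1 * 0.5 * 0.25.
    assert_eq!(toy_vec_dot(&[x], &[y]), -0.125);
}

Keeping the accumulation in integers is what the SIMD versions of these kernels exploit; the generic loop in this commit is the portable fallback.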
@@ -491,6 +491,9 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
         GgmlDType::Q5_0 => 0.001353,
         GgmlDType::Q5_1 => 0.001363,
         GgmlDType::Q8_0 => 0.000092,
+
+        // Not from the ggml repo.
+        GgmlDType::Q8K => 0.00065,
         _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
     };
     Ok(err)
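A note on the new table entry: the 0.00065 threshold for Q8K is explicitly flagged as not coming from the ggml repo. The sketch below shows the kind of normalized error metric such per-dtype thresholds typically gate on; the relative-RMS formulation and the relative_rms_error name are assumptions for illustration, not a copy of candle's ggml_matmul_error_test.

// Hypothetical error metric: RMS of (reference - quantized), normalized by
// the RMS of the reference. An assumed formulation, not candle's test code.
fn relative_rms_error(reference: &[f32], quantized: &[f32]) -> f32 {
    assert_eq!(reference.len(), quantized.len());
    let err: f32 = reference
        .iter()
        .zip(quantized.iter())
        .map(|(r, q)| (r - q) * (r - q))
        .sum();
    let norm: f32 = reference.iter().map(|r| r * r).sum();
    (err / norm).sqrt()
}

fn main() {
    let reference = [1.0f32, -2.0, 3.0];
    let quantized = [1.01f32, -1.98, 3.02];
    let e = relative_rms_error(&reference, &quantized);
    println!("relative rms error: {e}"); // ~0.008 on this toy data
    assert!(e < 0.01);
}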
@@ -692,3 +695,28 @@ fn quantized_matmul_q6k() -> Result<()> {
     ggml_matmul_error_test::<BlockQ6K>()?;
     Ok(())
 }
+
+#[test]
+fn quantized_matmul_q8k() -> Result<()> {
+    use k_quants::BlockQ8K;
+
+    let cpu = &Device::Cpu;
+    let (m, k, n) = (11, 512, 21);
+    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
+
+    let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
+    let mm = rhs.forward(&lhs)?;
+
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.266, 1.504, -0.204, 1.7]);
+
+    ggml_matmul_error_test::<BlockQ8K>()?;
+    Ok(())
+}
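The new test doubles as a usage example for the Q8K path: quantize an f32 weight into Q8K blocks, wrap it in QMatMul, and run forward. A standalone sketch along those lines, assuming the API exactly as exercised in the test above, a Linear-style (out_features, in_features) weight layout, and an inner dimension divisible by the 256-element Q8K block size:

use candle_core::quantized::{self, k_quants::BlockQ8K};
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let cpu = &Device::Cpu;
    // Shapes mirror the test: 11x512 activations against a 21x512 weight;
    // the inner dimension (512) must be a multiple of the 256-wide Q8K block.
    let weight = Tensor::randn(0f32, 1f32, (21, 512), cpu)?;
    let input = Tensor::randn(0f32, 1f32, (11, 512), cpu)?;

    let qweight = quantized::QTensor::quantize::<BlockQ8K>(&weight)?;
    let qmatmul = quantized::QMatMul::from_qtensor(qweight)?;
    // Depending on the candle version, `forward` may come from the
    // candle_core::Module trait, which would then need to be in scope.
    let output = qmatmul.forward(&input)?;
    assert_eq!(output.dims(), [11, 21]);
    Ok(())
}

Only the weight (rhs) is quantized while the activations stay in f32, which matches what the test checks: the sampled outputs move from [1.262, 1.513, -0.208, 1.702] to [1.266, 1.504, -0.204, 1.7], a small quantization error in line with the Q8K threshold added above.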