From 098909de40b1478dfd6fba92f9907b8cd88984a6 Mon Sep 17 00:00:00 2001
From: Laurent Mazare <laurent.mazare@gmail.com>
Date: Wed, 16 Aug 2023 20:59:40 +0100
Subject: [PATCH] Add vecdot for q6k-q8k. (#476)

* Add vecdot for q6k-q8k.

* Add some testing for q8k.

* Use QMatMul for the output layer.
---
 candle-core/src/quantized/k_quants.rs | 58 ++++++++++++++++++++++++++-
 candle-core/tests/quantized_tests.rs  | 22 ++++++++++
 candle-examples/examples/ggml/main.rs |  6 +--
 3 files changed, 80 insertions(+), 6 deletions(-)

diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 28ac896e..2f622026 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -462,8 +462,62 @@ impl GgmlType for BlockQ6K {
     const BLCK_SIZE: usize = QK_K;
     type VecDotType = BlockQ8K;
 
-    fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
-        todo!()
+    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        if n % QK_K != 0 {
+            crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
+        }
+
+        let mut aux8 = [0i8; QK_K];
+        let mut aux16 = [0i16; 8];
+        let mut sums = [0f32; 8];
+        let mut aux32 = [0f32; 8];
+
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let q4 = &x.ql;
+            let qh = &x.qh;
+            let q8 = &y.qs;
+            aux32.fill(0f32);
+
+            for j in (0..QK_K).step_by(128) {
+                let aux8 = &mut aux8[j..];
+                let q4 = &q4[j / 2..];
+                let qh = &qh[j / 4..];
+                for l in 0..32 {
+                    aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
+                    aux8[l + 32] =
+                        (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
+                    aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
+                    aux8[l + 96] =
+                        (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
+                }
+            }
+
+            for (j, &scale) in x.scales.iter().enumerate() {
+                let scale = scale as f32;
+                let q8 = &q8[16 * j..];
+                let aux8 = &aux8[16 * j..];
+                for l in 0..8 {
+                    aux16[l] = q8[l] as i16 * aux8[l] as i16;
+                }
+                for l in 0..8 {
+                    aux32[l] += scale * aux16[l] as f32
+                }
+                let q8 = &q8[8..];
+                let aux8 = &aux8[8..];
+                for l in 0..8 {
+                    aux16[l] = q8[l] as i16 * aux8[l] as i16;
+                }
+                for l in 0..8 {
+                    aux32[l] += scale * aux16[l] as f32
+                }
+            }
+
+            let d = x.d.to_f32() * y.d;
+            for (sum, &a) in sums.iter_mut().zip(aux32.iter()) {
+                *sum += a * d;
+            }
+        }
+        Ok(sums.iter().sum())
     }
 
     fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index 9c5168bf..b40a7fdb 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -123,3 +123,25 @@ fn quantize_q4_0() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn quantize_q8k() -> Result<()> {
+    use k_quants::BlockQ8K;
+
+    let src = (0..256 * 4)
+        .map(|v| (v as f32 - 512.) / 1024.)
+        .collect::<Vec<_>>();
+    let mut dst = vec![0f32; 256 * 4];
+    let mut quant = vec![BlockQ8K::zeros(); 4];
+    BlockQ8K::from_float(&src, &mut quant)?;
+    BlockQ8K::to_float(&quant, dst.as_mut_slice())?;
+    assert_eq!(
+        [src[0], src[128], src[256], src[512], src[800], src[1023]],
+        [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
+    );
+    assert_eq!(
+        [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
+        [-0.5, -0.375, -0.25, -0.0, 0.28070068, 0.49902344]
+    );
+    Ok(())
+}
diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/ggml/main.rs
index 54bc5f57..f42d6f0f 100644
--- a/candle-examples/examples/ggml/main.rs
+++ b/candle-examples/examples/ggml/main.rs
@@ -155,8 +155,7 @@ struct ModelWeights {
     tok_embeddings: Embedding,
     layers: Vec<LayerWeights>,
     norm: RmsNorm,
-    // TODO: Switch to using QMatMul instead of linear once we have support for Q6K/Q8K.
-    output: candle_nn::Linear,
+    output: QMatMul,
     masks: HashMap<usize, Tensor>,
     span: tracing::Span,
     span_output: tracing::Span,
@@ -197,7 +196,6 @@ impl ModelWeights {
         let tok_embeddings = tok_embeddings.dequantize(cpu)?;
         let norm = RmsNorm::new(ct.remove("norm.weight")?)?;
         let output = ct.remove("output.weight")?;
-        let output = candle_nn::Linear::new(output.dequantize(cpu)?, None);
         let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
         for layer_idx in 0..ct.hparams.n_layer {
             let prefix = format!("layers.{layer_idx}");
@@ -239,7 +237,7 @@ impl ModelWeights {
            tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
            layers,
            norm,
-            output,
+            output: QMatMul::from_qtensor(output),
             masks: HashMap::new(),
             span,
             span_output,