From aa76b783eb9daac6d91cdc85a0526950bda1e524 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 3 Aug 2023 09:31:20 +0100
Subject: [PATCH] Q6K dequantization. (#315)

---
 candle-core/src/ggml.rs | 44 ++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/candle-core/src/ggml.rs b/candle-core/src/ggml.rs
index 5af2acdb..f6418942 100644
--- a/candle-core/src/ggml.rs
+++ b/candle-core/src/ggml.rs
@@ -170,7 +170,7 @@ fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
 fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
     let k = ys.len();
     if k % QK_K != 0 {
-        crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
+        crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}")
     }
     let mut ys_index = 0;
     for x in xs.iter() {
@@ -202,6 +202,39 @@ fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
     Ok(())
 }
 
+// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
+fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> {
+    let k = ys.len();
+    if k % QK_K != 0 {
+        crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
+    }
+    for (idx_x, x) in xs.iter().enumerate() {
+        let d = x.d.to_f32();
+        let ql = &x.ql;
+        let qh = &x.qh;
+        let sc = &x.scales;
+        for n in (0..QK_K).step_by(128) {
+            let idx = n / 128;
+            let ys = &mut ys[idx_x * QK_K + n..];
+            let sc = &sc[8 * idx..];
+            let ql = &ql[64 * idx..];
+            let qh = &qh[32 * idx..];
+            for l in 0..32 {
+                let is = l / 16;
+                let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
+                let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
+                let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
+                let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
+                ys[l] = d * sc[is] as f32 * q1 as f32;
+                ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
+                ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
+                ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
+            }
+        }
+    }
+    Ok(())
+}
+
 // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum Magic {
@@ -442,6 +475,15 @@ fn read_one_tensor(
             dequantize_row_q4k(raw_data, &mut f32_data)?;
             Tensor::from_vec(f32_data, dims, device)?
         }
+        GgmlDType::Q6K => {
+            let mut f32_data = vec![0f32; tensor_elems];
+            let raw_data_ptr = raw_data.as_ptr();
+            let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ6K>();
+            let raw_data =
+                unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ6K, n_blocks) };
+            dequantize_row_q6k(raw_data, &mut f32_data)?;
+            Tensor::from_vec(f32_data, dims, device)?
+        }
         _ => crate::bail!("quantized type {dtype:?} used in {name} is not supported yet"),
     };
     Ok((name, tensor))
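
Not part of the commit above: a minimal, self-contained sketch of the per-block unpacking that the new dequantize_row_q6k performs, for readers who want to try the bit manipulation in isolation. It assumes QK_K = 256 and the llama.cpp k_quants Q6K layout (ql: 128 bytes of low nibbles, qh: 64 bytes holding four 2-bit high parts per byte, scales: 16 signed 8-bit sub-block scales, d: one per-block scale); d is taken here as f32 to avoid the half dependency, and the function name is illustrative rather than candle API.

    const QK_K: usize = 256;

    /// Dequantize a single Q6K block: 256 six-bit values split across `ql`
    /// (low 4 bits) and `qh` (high 2 bits), scaled by a signed per-16-element
    /// sub-block scale and a per-block scale `d`.
    fn dequantize_q6k_block(
        ql: &[u8; QK_K / 2],      // 128 bytes, two low nibbles per byte
        qh: &[u8; QK_K / 4],      // 64 bytes, four 2-bit high parts per byte
        scales: &[i8; QK_K / 16], // 16 sub-block scales
        d: f32,                   // per-block scale (stored as f16 in the real block)
        out: &mut [f32; QK_K],
    ) {
        // Each half of the block produces 128 outputs.
        for n in (0..QK_K).step_by(128) {
            let idx = n / 128;
            let ql = &ql[64 * idx..];
            let qh = &qh[32 * idx..];
            let sc = &scales[8 * idx..];
            let out = &mut out[n..];
            for l in 0..32 {
                let is = l / 16;
                // Recombine the low 4 bits with the high 2 bits, then shift the
                // unsigned 6-bit value [0, 63] into the signed range [-32, 31].
                let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
                let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
                let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
                let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
                out[l] = d * sc[is] as f32 * q1 as f32;
                out[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
                out[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
                out[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
            }
        }
    }

Each block expands to 256 consecutive f32 values, which is why the function in the patch advances the output slice by QK_K per block when dequantizing a whole tensor.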
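
Also not from the commit: the read_one_tensor arm reinterprets the raw GGML byte buffer as a slice of BlockQ6K, so the buffer must hold a whole number of blocks. Under the layout assumed above, one block costs QK_K/2 + QK_K/4 + QK_K/16 + 2 = 210 bytes for 256 weights, i.e. 6.5625 bits per weight. A small, hypothetical sanity-check helper along those lines:

    const QK_K: usize = 256;
    // Bytes per Q6K block: ql (128) + qh (64) + scales (16) + d (2) = 210.
    const Q6K_BLOCK_BYTES: usize = QK_K / 2 + QK_K / 4 + QK_K / 16 + 2;

    /// Expected storage size in bytes for a Q6K tensor with `n_elems` weights,
    /// or None when the element count is not a multiple of QK_K (the format
    /// only stores whole 256-element blocks).
    fn q6k_size_in_bytes(n_elems: usize) -> Option<usize> {
        if n_elems % QK_K == 0 {
            Some(n_elems / QK_K * Q6K_BLOCK_BYTES)
        } else {
            None
        }
    }

    fn main() {
        let n_elems = 4096 * 4096;
        let bytes = q6k_size_in_bytes(n_elems).unwrap();
        println!(
            "{n_elems} weights -> {} blocks, {bytes} bytes ({:.4} bits/weight)",
            n_elems / QK_K,
            (bytes * 8) as f64 / n_elems as f64,
        );
    }

Comparing size_in_bytes against such an expected value before the unsafe from_raw_parts call is one way to catch truncated or mismatched tensor data early.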