From 8b6f5be1cc16f5e50fa6137a9e9faf75e73c20c2 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 4 Aug 2023 09:51:30 +0100
Subject: [PATCH] Support q5k quantized data. (#320)

---
 candle-core/src/ggml.rs | 76 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 76 insertions(+)

diff --git a/candle-core/src/ggml.rs b/candle-core/src/ggml.rs
index 3a0aa5a1..4a5d4fa0 100644
--- a/candle-core/src/ggml.rs
+++ b/candle-core/src/ggml.rs
@@ -202,6 +202,64 @@ fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
     Ok(())
 }
 
+// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
+// Placeholder for the q3k dequantization, which is not implemented yet. It reports an error
+// rather than panicking so that the caller's `?` can propagate the failure.
+fn dequantize_row_q3k(_xs: &[BlockQ3K], _ys: &mut [f32]) -> Result<()> {
+    crate::bail!("dequantize_row_q3k: q3k dequantization is not implemented yet")
+}
+
+// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928
+// Dequantizes q5k blocks from `xs` into `ys`, writing QK_K values per block.
+//
+// Every group of 64 outputs reads 32 bytes of `qs`: the low nibbles produce the first 32 values
+// and the high nibbles the next 32. Each half has its own (scale, min) pair obtained from
+// `get_scale_min_k4`, and the fifth bit of each quant comes from `qh`, selected by the masks
+// `u1` and `u2`.
+//
+// Returns an error if `ys.len()` is not a multiple of QK_K or if `ys` is too small to hold the
+// output of all the blocks in `xs`.
+fn dequantize_row_q5k(xs: &[BlockQ5K], ys: &mut [f32]) -> Result<()> {
+    let k = ys.len();
+    if k % QK_K != 0 {
+        crate::bail!("dequantize_row_q5k: {k} is not divisible by {QK_K}")
+    }
+    let n_values = xs.len() * QK_K;
+    if n_values > k {
+        crate::bail!("dequantize_row_q5k: {n_values} values do not fit in an output of size {k}")
+    }
+    for (x, ys) in xs.iter().zip(ys.chunks_exact_mut(QK_K)) {
+        let d = x.d.to_f32();
+        let min = x.dmin.to_f32();
+        let mut is = 0;
+        // Masks selecting the high bit in `qh` for the low-nibble and high-nibble halves.
+        let mut u1 = 1;
+        let mut u2 = 2;
+        for (ql, ys) in x.qs.chunks_exact(32).zip(ys.chunks_exact_mut(64)) {
+            let (sc, m) = get_scale_min_k4(is, &x.scales);
+            let d1 = d * sc as f32;
+            let m1 = min * m as f32;
+            let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
+            let d2 = d * sc as f32;
+            let m2 = min * m as f32;
+            let (ys_lo, ys_hi) = ys.split_at_mut(32);
+            for ((y, ql), qh) in ys_lo.iter_mut().zip(ql).zip(&x.qh) {
+                // A set bit contributes 16, a cleared bit contributes nothing.
+                let to_add = if qh & u1 != 0 { 16 } else { 0 };
+                *y = d1 * ((ql & 0xF) + to_add) as f32 - m1;
+            }
+            for ((y, ql), qh) in ys_hi.iter_mut().zip(ql).zip(&x.qh) {
+                let to_add = if qh & u2 != 0 { 16 } else { 0 };
+                *y = d2 * ((ql >> 4) + to_add) as f32 - m2;
+            }
+            is += 2;
+            u1 <<= 2;
+            u2 <<= 2;
+        }
+    }
+    Ok(())
+}
+
 // 
https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067 fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> { let k = ys.len(); @@ -466,6 +514,15 @@ fn read_one_tensor( // Maybe we should use bf16 instead? Tensor::from_vec(f32_data, dims, device)? } + GgmlDType::Q3K => { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ3K>(); + let raw_data = + unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ3K, n_blocks) }; + dequantize_row_q3k(raw_data, &mut f32_data)?; + Tensor::from_vec(f32_data, dims, device)? + } GgmlDType::Q4K => { let mut f32_data = vec![0f32; tensor_elems]; let raw_data_ptr = raw_data.as_ptr(); @@ -475,6 +532,15 @@ fn read_one_tensor( dequantize_row_q4k(raw_data, &mut f32_data)?; Tensor::from_vec(f32_data, dims, device)? } + GgmlDType::Q5K => { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ5K>(); + let raw_data = + unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ5K, n_blocks) }; + dequantize_row_q5k(raw_data, &mut f32_data)?; + Tensor::from_vec(f32_data, dims, device)? + } GgmlDType::Q6K => { let mut f32_data = vec![0f32; tensor_elems]; let raw_data_ptr = raw_data.as_ptr();