mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Q6K dequantization. (#315)
This commit is contained in:
@ -170,7 +170,7 @@ fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
|
|||||||
fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
|
fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
|
||||||
let k = ys.len();
|
let k = ys.len();
|
||||||
if k % QK_K != 0 {
|
if k % QK_K != 0 {
|
||||||
crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
|
crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}")
|
||||||
}
|
}
|
||||||
let mut ys_index = 0;
|
let mut ys_index = 0;
|
||||||
for x in xs.iter() {
|
for x in xs.iter() {
|
||||||
@ -202,6 +202,39 @@ fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
|
||||||
|
fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> {
|
||||||
|
let k = ys.len();
|
||||||
|
if k % QK_K != 0 {
|
||||||
|
crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
|
||||||
|
}
|
||||||
|
for x in xs.iter() {
|
||||||
|
let d = x.d.to_f32();
|
||||||
|
let ql = &x.ql;
|
||||||
|
let qh = &x.qh;
|
||||||
|
let sc = &x.scales;
|
||||||
|
for n in (0..QK_K).step_by(128) {
|
||||||
|
let idx = n / 128;
|
||||||
|
let ys = &mut ys[n..];
|
||||||
|
let sc = &sc[8 * idx..];
|
||||||
|
let ql = &ql[64 * idx..];
|
||||||
|
let qh = &qh[32 * idx..];
|
||||||
|
for l in 0..32 {
|
||||||
|
let is = l / 16;
|
||||||
|
let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
|
||||||
|
let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
|
||||||
|
let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
|
||||||
|
let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
|
||||||
|
ys[l] = d * sc[is] as f32 * q1 as f32;
|
||||||
|
ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
|
||||||
|
ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
|
||||||
|
ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
|
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
enum Magic {
|
enum Magic {
|
||||||
@ -442,6 +475,15 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
|||||||
dequantize_row_q4k(raw_data, &mut f32_data)?;
|
dequantize_row_q4k(raw_data, &mut f32_data)?;
|
||||||
Tensor::from_vec(f32_data, dims, device)?
|
Tensor::from_vec(f32_data, dims, device)?
|
||||||
}
|
}
|
||||||
|
GgmlDType::Q6K => {
|
||||||
|
let mut f32_data = vec![0f32; tensor_elems];
|
||||||
|
let raw_data_ptr = raw_data.as_ptr();
|
||||||
|
let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ6K>();
|
||||||
|
let raw_data =
|
||||||
|
unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ6K, n_blocks) };
|
||||||
|
dequantize_row_q6k(raw_data, &mut f32_data)?;
|
||||||
|
Tensor::from_vec(f32_data, dims, device)?
|
||||||
|
}
|
||||||
_ => crate::bail!("quantized type {dtype:?} used in {name} is not supported yet"),
|
_ => crate::bail!("quantized type {dtype:?} used in {name} is not supported yet"),
|
||||||
};
|
};
|
||||||
Ok((name, tensor))
|
Ok((name, tensor))
|
||||||
|
Reference in New Issue
Block a user