diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 0fecf379..bfc471a3 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1,3 +1,7 @@ +use super::utils::{ + get_scale_min_k4, group_for_dequantization, group_for_quantization, make_q3_quants, + make_qkx1_quants, make_qx_quants, nearest_int, +}; use super::GgmlDType; use crate::Result; use half::f16; @@ -146,6 +150,104 @@ pub struct BlockQ8K { } const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::()); +impl GgmlType for BlockQ4_0 { + const DTYPE: GgmlDType = GgmlDType::Q4_0; + const BLCK_SIZE: usize = QK4_0; + type VecDotType = BlockQ8_0; + + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + let qk = Self::BLCK_SIZE; + if k % qk != 0 { + crate::bail!("dequantize_row_q4_0: {k} is not divisible by {qk}") + } + + let nb = k / qk; + for i in 0..nb { + let d = xs[i].d.to_f32(); + + for j in 0..(qk / 2) { + let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8; + let x1 = (xs[i].qs[j] >> 4) as i16 - 8; + + ys[i * qk + j] = (x0 as f32) * d; + ys[i * qk + j + qk / 2] = (x1 as f32) * d; + } + } + Ok(()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + // quantize_row_q4_0 + let qk = Self::BLCK_SIZE; + let k = xs.len(); + if k % qk != 0 { + crate::bail!("{k} is not divisible by {}", qk); + }; + let nb = k / qk; + if ys.len() != nb { + crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,) + } + for (i, ys) in ys.iter_mut().enumerate() { + let mut amax = 0f32; + let mut max = 0f32; + + let xs = &xs[i * qk..(i + 1) * qk]; + for &x in xs.iter() { + if amax < x.abs() { + amax = x.abs(); + max = x; + } + } + let d = max / -8.0; + let id = if d != 0f32 { 1. / d } else { 0. }; + ys.d = f16::from_f32(d); + + for (j, q) in ys.qs.iter_mut().enumerate() { + let x0 = xs[j] * id; + let x1 = xs[qk / 2 + j] * id; + let xi0 = u8::min(15, (x0 + 8.5) as u8); + let xi1 = u8::min(15, (x1 + 8.5) as u8); + *q = xi0 | (xi1 << 4) + } + } + Ok(()) + } + + // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122 + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q4_0_q8_0(n, xs, ys); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q4_0_q8_0(n, xs, ys); + + let qk = QK8_0; + let nb = n / qk; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") + } + if nb % 2 != 0 { + crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") + } + + // Generic implementation. + let mut sumf = 0f32; + for (xs, ys) in xs.iter().zip(ys.iter()) { + let mut sum_i = 0; + for j in 0..qk / 2 { + let v0 = (xs.qs[j] & 0x0F) as i32 - 8; + let v1 = (xs.qs[j] >> 4) as i32 - 8; + sum_i += v0 * ys.qs[j] as i32 + v1 * ys.qs[j + qk / 2] as i32 + } + sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) + } + Ok(sumf) + } +} + impl GgmlType for BlockQ4_1 { const DTYPE: GgmlDType = GgmlDType::Q4_1; const BLCK_SIZE: usize = QK4_1; @@ -264,129 +366,70 @@ impl GgmlType for BlockQ5_1 { } } -impl GgmlType for BlockQ2K { - const DTYPE: GgmlDType = GgmlDType::Q2K; - const BLCK_SIZE: usize = QK_K; - type VecDotType = BlockQ8K; +impl GgmlType for BlockQ8_0 { + const DTYPE: GgmlDType = GgmlDType::Q8_0; + const BLCK_SIZE: usize = QK8_0; + type VecDotType = BlockQ8_0; - fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result { - todo!() - } - - fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> { - todo!() - } - // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354 + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { let k = ys.len(); - if k % QK_K != 0 { - crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}") + if k % QK8_0 != 0 { + crate::bail!("dequantize_row_q8_0: {k} is not divisible by {QK8_0}"); } - let mut ys_index = 0; - for x in xs { - let d = x.d.to_f32(); - let min = x.dmin.to_f32(); - let q = &x.qs; - let mut is = 0; - for n in (0..QK_K).step_by(128) { - // Step by 32 over q. - let q = &q[n / 4..]; - let mut shift = 0; - for _j in 0..4 { - let sc = x.scales[is]; - is += 1; - let dl = d * (sc & 0xF) as f32; - let ml = min * (sc >> 4) as f32; - for q in &q[..16] { - let y = dl * ((q >> shift) & 3) as i8 as f32 - ml; - ys[ys_index] = y; - ys_index += 1; - } + let nb = k / QK8_0; - let sc = x.scales[is]; - is += 1; - let dl = d * (sc & 0xF) as f32; - let ml = min * (sc >> 4) as f32; - for q in &q[16..32] { - let y = dl * ((q >> shift) & 3) as i8 as f32 - ml; - ys[ys_index] = y; - ys_index += 1; - } + for i in 0..nb { + let d = xs[i].d.to_f32(); - shift += 2; - } + for j in 0..QK8_0 { + ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d; } } Ok(()) } -} -fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) { - if j < 4 { - let d = q[j] & 63; - let m = q[j + 4] & 63; - (d, m) - } else { - let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); - let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4); - (d, m) - } -} - -impl GgmlType for BlockQ4K { - const DTYPE: GgmlDType = GgmlDType::Q4K; - const BLCK_SIZE: usize = QK_K; - type VecDotType = BlockQ8K; - - fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result { - todo!() - } - - fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> { - todo!() - } - // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735 - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { - let k = ys.len(); - if k % QK_K != 0 { - crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}") + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + // quantize_row_q8_0 + let k = xs.len(); + if k % Self::BLCK_SIZE != 0 { + crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE); + }; + let nb = k / Self::BLCK_SIZE; + if ys.len() != nb { + crate::bail!( + "size mismatch {} {} {}", + xs.len(), + ys.len(), + Self::BLCK_SIZE + ) } - let mut ys_index = 0; - for x in xs.iter() { - let d = x.d.to_f32(); - let min = x.dmin.to_f32(); - let q = &x.qs; - let mut is = 0; - for j in (0..QK_K).step_by(64) { - let q = &q[j / 2..j / 2 + 32]; - let (sc, m) = get_scale_min_k4(is, &x.scales); - let d1 = d * sc as f32; - let m1 = min * m as f32; - let (sc, m) = get_scale_min_k4(is + 1, &x.scales); - let d2 = d * sc as f32; - let m2 = min * m as f32; - for q in q { - let y = d1 * (q & 0xF) as f32 - m1; - ys[ys_index] = y; - ys_index += 1; - } - for q in q { - let y = d2 * (q >> 4) as f32 - m2; - ys[ys_index] = y; - ys_index += 1; - } - is += 2; + for (i, ys) in ys.iter_mut().enumerate() { + let mut amax = 0f32; + let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; + for &x in xs.iter() { + amax = amax.max(x.abs()) + } + let d = amax / ((1 << 7) - 1) as f32; + let id = if d != 0f32 { 1. / d } else { 0. }; + ys.d = f16::from_f32(d); + for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) { + *y = f32::round(x * id) as i8 } } Ok(()) } + + fn vec_dot(_: usize, _: &[Self], _: &[Self::VecDotType]) -> Result { + todo!() + } } -impl GgmlType for BlockQ3K { +impl GgmlType for BlockQ8_1 { const DTYPE: GgmlDType = GgmlDType::Q3K; const BLCK_SIZE: usize = QK_K; - type VecDotType = BlockQ8K; + type VecDotType = BlockQ8_1; fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result { todo!() @@ -402,6 +445,378 @@ impl GgmlType for BlockQ3K { } } +impl GgmlType for BlockQ2K { + const DTYPE: GgmlDType = GgmlDType::Q2K; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + + fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result { + todo!() + } + + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L279 + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + const Q4SCALE: f32 = 15.0; + + for (block, x) in group_for_quantization(xs, ys)? { + //calculate scales and mins + let mut mins: [f32; QK_K / 16] = [0.0; QK_K / 16]; + let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16]; + + for (j, x_scale_slice) in x.chunks(16).enumerate() { + (scales[j], mins[j]) = make_qkx1_quants(3, 5, x_scale_slice); + } + // get max scale and max min and ensure they are >= 0.0 + let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max)); + let max_min = mins.iter().fold(0.0, |max, &val| val.max(max)); + + if max_scale > 0.0 { + let iscale = Q4SCALE / max_scale; + for (j, scale) in scales.iter().enumerate().take(QK_K / 16) { + block.scales[j] = nearest_int(iscale * scale) as u8; + } + block.d = f16::from_f32(max_scale / Q4SCALE); + } else { + for j in 0..QK_K / 16 { + block.scales[j] = 0; + } + block.d = f16::from_f32(0.0); + } + + if max_min > 0.0 { + let iscale = Q4SCALE / max_min; + for (j, scale) in block.scales.iter_mut().enumerate() { + let l = nearest_int(iscale * mins[j]) as u8; + *scale |= l << 4; + } + block.dmin = f16::from_f32(max_min / Q4SCALE); + } else { + block.dmin = f16::from_f32(0.0); + } + + let mut big_l: [u8; QK_K] = [0; QK_K]; + + for j in 0..QK_K / 16 { + let d = block.d.to_f32() * (block.scales[j] & 0xF) as f32; + if d == 0.0 { + continue; + } + let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32; + for ii in 0..16 { + let mut ll = nearest_int((x[16 * j + ii] + dm) / d); + ll = ll.clamp(0, 3); + big_l[16 * j + ii] = ll as u8; + } + } + + for j in (0..QK_K).step_by(128) { + for ll in 0..32 { + block.qs[j / 4 + ll] = big_l[j + ll] + | (big_l[j + ll + 32] << 2) + | (big_l[j + ll + 64] << 4) + | (big_l[j + ll + 96] << 6); + } + } + } + Ok(()) + } + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + for (block, y) in group_for_dequantization(xs, ys)? { + let d = block.d.to_f32(); + let min = block.dmin.to_f32(); + + let mut is = 0; + + for (y_block, qs) in y.chunks_exact_mut(128).zip(block.qs.chunks_exact(32)) { + // Step by 32 over q. + let mut shift = 0; + let mut y_block_index = 0; + for _j in 0..4 { + let sc = block.scales[is]; + is += 1; + let dl = d * (sc & 0xF) as f32; + let ml = min * (sc >> 4) as f32; + for q in &qs[..16] { + let y = dl * ((q >> shift) & 3) as f32 - ml; + y_block[y_block_index] = y; + y_block_index += 1; + } + + let sc = block.scales[is]; + is += 1; + let dl = d * (sc & 0xF) as f32; + let ml = min * (sc >> 4) as f32; + for q in &qs[16..] { + let y = dl * ((q >> shift) & 3) as f32 - ml; + y_block[y_block_index] = y; + y_block_index += 1; + } + + shift += 2; + } + } + } + Ok(()) + } +} + +impl GgmlType for BlockQ3K { + const DTYPE: GgmlDType = GgmlDType::Q3K; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + + fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result { + todo!() + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + for (block, x) in group_for_quantization(xs, ys)? { + let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16]; + for (j, x_scale_slice) in x.chunks_exact(16).enumerate() { + scales[j] = make_q3_quants(x_scale_slice, 4, true); + } + + // Get max scale by absolute value. + let max_scale = scales + .iter() + .fold(0.0, |max, &val| if val.abs() > max { val } else { max }); + + block.scales.fill(0); + + if max_scale != 0.0 { + let iscale = -32.0 / max_scale; + for (j, scale) in scales.iter().enumerate() { + let mut l_val = nearest_int(iscale * scale); + l_val = l_val.clamp(-32, 31) + 32; + if j < 8 { + block.scales[j] = (l_val & 0xF) as u8; + } else { + block.scales[j - 8] |= ((l_val & 0xF) << 4) as u8; + } + l_val >>= 4; + block.scales[j % 4 + 8] |= (l_val << (2 * (j / 4))) as u8; + } + block.d = f16::from_f32(1.0 / iscale); + } else { + block.d = f16::from_f32(0.0); + } + + let mut l: [i8; QK_K] = [0; QK_K]; + + for j in 0..QK_K / 16 { + let sc = if j < 8 { + block.scales[j] & 0xF + } else { + block.scales[j - 8] >> 4 + }; + let sc = (sc | (((block.scales[8 + j % 4] >> (2 * (j / 4))) & 3) << 4)) as i8 - 32; + let d = block.d.to_f32() * sc as f32; + if d != 0.0 { + for ii in 0..16 { + let mut l_val = nearest_int(x[16 * j + ii] / d); + l_val = l_val.clamp(-4, 3); + l[16 * j + ii] = (l_val + 4) as i8; + } + } + } + + block.hmask.fill(0); + let mut m = 0; + let mut hm = 1; + + for ll in l.iter_mut() { + if *ll > 3 { + block.hmask[m] |= hm; + *ll -= 4; + } + m += 1; + if m == QK_K / 8 { + m = 0; + hm <<= 1; + } + } + + for j in (0..QK_K).step_by(128) { + for l_val in 0..32 { + block.qs[j / 4 + l_val] = (l[j + l_val] + | (l[j + l_val + 32] << 2) + | (l[j + l_val + 64] << 4) + | (l[j + l_val + 96] << 6)) + as u8; + } + } + } + + Ok(()) + } + + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + const KMASK1: u32 = 0x03030303; + const KMASK2: u32 = 0x0f0f0f0f; + + for (block, y) in group_for_dequantization(xs, ys)? { + //Reconstruct the scales + let mut aux = [0; 4]; + let aux_raw = unsafe { + std::mem::transmute::<&mut [u8; 12], &mut [u32; 3]>(&mut block.scales.clone()) + }; + aux[0..3].copy_from_slice(aux_raw); + + let tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4); + aux[3] = ((aux[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4); + aux[0] = (aux[0] & KMASK2) | (((tmp) & KMASK1) << 4); + aux[1] = (aux[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4); + + //Transfer the scales into an i8 array + let scales: &mut [i8] = + unsafe { std::slice::from_raw_parts_mut(aux.as_mut_ptr() as *mut i8, 16) }; + + let d_all = block.d.to_f32(); + let mut m = 1; + let mut is = 0; + let mut dl; + + // Dequantize both 128 long blocks + // 32 qs values per 128 long block + // Each 16 elements get a scale + for (y, qs) in y.chunks_exact_mut(128).zip(block.qs.chunks_exact(32)) { + let mut shift = 0; + for shift_scoped_y in y.chunks_exact_mut(32) { + for (scale_index, scale_scoped_y) in + shift_scoped_y.chunks_exact_mut(16).enumerate() + { + dl = d_all * (scales[is] as f32 - 32.0); + for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() { + let new_y = dl + * (((qs[i + 16 * scale_index] >> shift) & 3) as i8 + - if (block.hmask[i + 16 * scale_index] & m) == 0 { + 4 + } else { + 0 + }) as f32; + *inner_y = new_y; + } + // 16 block finished => advance scale index + is += 1; + } + //32 block finished => increase shift and m + shift += 2; + m <<= 1; + } + } + } + + Ok(()) + } +} + +impl GgmlType for BlockQ4K { + const DTYPE: GgmlDType = GgmlDType::Q4K; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + + fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result { + todo!() + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + for (block, x) in group_for_quantization(xs, ys)? { + let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32]; + let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32]; + + for (j, x_scale_slice) in x.chunks_exact(32).enumerate() { + (scales[j], mins[j]) = make_qkx1_quants(15, 5, x_scale_slice); + } + + // get max scale and max min and ensure they are >= 0.0 + let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max)); + let max_min = mins.iter().fold(0.0, |max, &val| val.max(max)); + + let inv_scale = if max_scale > 0.0 { + 63.0 / max_scale + } else { + 0.0 + }; + let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 }; + + for j in 0..QK_K / 32 { + let mut ls = nearest_int(inv_scale * scales[j]) as u8; + let mut lm = nearest_int(inv_min * mins[j]) as u8; + ls = std::cmp::min(63, ls); + lm = std::cmp::min(63, lm); + if j < 4 { + block.scales[j] = ls; + block.scales[j + 4] = lm; + } else { + block.scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4); + block.scales[j - 4] |= (ls >> 4) << 6; + block.scales[j] |= (lm >> 4) << 6; + } + } + + block.d = f16::from_f32(max_scale / 63.0); + block.dmin = f16::from_f32(max_min / 63.0); + + let mut l: [u8; QK_K] = [0; QK_K]; + + for j in 0..QK_K / 32 { + let (sc, m) = get_scale_min_k4(j, &block.scales); + let d = block.d.to_f32() * sc as f32; + if d != 0.0 { + let dm = block.dmin.to_f32() * m as f32; + for ii in 0..32 { + let mut l_val = nearest_int((x[32 * j + ii] + dm) / d); + l_val = l_val.clamp(0, 15); + l[32 * j + ii] = l_val as u8; + } + } + } + + let q = &mut block.qs; + for j in (0..QK_K).step_by(64) { + for l_val in 0..32 { + let offset_index = (j / 64) * 32 + l_val; + q[offset_index] = l[j + l_val] | (l[j + l_val + 32] << 4); + } + } + } + Ok(()) + } + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + for (block, y) in group_for_dequantization(xs, ys)? { + let d = block.d.to_f32(); + let min = block.dmin.to_f32(); + let q = &block.qs; + let mut is = 0; + let mut ys_index = 0; + + for j in (0..QK_K).step_by(64) { + let q = &q[j / 2..j / 2 + 32]; + let (sc, m) = get_scale_min_k4(is, &block.scales); + let d1 = d * sc as f32; + let m1 = min * m as f32; + let (sc, m) = get_scale_min_k4(is + 1, &block.scales); + let d2 = d * sc as f32; + let m2 = min * m as f32; + for q in q { + y[ys_index] = d1 * (q & 0xF) as f32 - m1; + ys_index += 1; + } + for q in q { + y[ys_index] = d2 * (q >> 4) as f32 - m2; + ys_index += 1; + } + is += 2; + } + } + Ok(()) + } +} + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928 impl GgmlType for BlockQ5K { const DTYPE: GgmlDType = GgmlDType::Q5K; @@ -412,41 +827,115 @@ impl GgmlType for BlockQ5K { todo!() } - fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> { - todo!() - } - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { - let k = ys.len(); - if k % QK_K != 0 { - crate::bail!("dequantize_row_q5k: {k} is not divisible by {QK_K}") + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L793 + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + for (block, x) in group_for_quantization(xs, ys)? { + let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32]; + let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32]; + + for (j, x_scale_slice) in x.chunks_exact(32).enumerate() { + (scales[j], mins[j]) = make_qkx1_quants(31, 5, x_scale_slice); + } + + // get max scale and max min and ensure they are >= 0.0 + let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max)); + let max_min = mins.iter().fold(0.0, |max, &val| val.max(max)); + + let inv_scale = if max_scale > 0.0 { + 63.0 / max_scale + } else { + 0.0 + }; + let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 }; + for j in 0..QK_K / 32 { + let mut ls = nearest_int(inv_scale * scales[j]) as u8; + let mut lm = nearest_int(inv_min * mins[j]) as u8; + ls = ls.min(63); + lm = lm.min(63); + if j < 4 { + block.scales[j] = ls; + block.scales[j + 4] = lm; + } else { + block.scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4); + block.scales[j - 4] |= (ls >> 4) << 6; + block.scales[j] |= (lm >> 4) << 6; + } + } + block.d = f16::from_f32(max_scale / 63.0); + block.dmin = f16::from_f32(max_min / 63.0); + + let mut l: [u8; QK_K] = [0; QK_K]; + for j in 0..QK_K / 32 { + let (sc, m) = get_scale_min_k4(j, &block.scales); + let d = block.d.to_f32() * sc as f32; + if d == 0.0 { + continue; + } + let dm = block.dmin.to_f32() * m as f32; + for ii in 0..32 { + let mut ll = nearest_int((x[32 * j + ii] + dm) / d); + ll = ll.min(31).max(0); + l[32 * j + ii] = ll as u8; + } + } + + let qh = &mut block.qh; + let ql = &mut block.qs; + qh.fill(0); + + let mut m1 = 1; + let mut m2 = 2; + for n in (0..QK_K).step_by(64) { + let offset = (n / 64) * 32; + for j in 0..32 { + let mut l1 = l[n + j]; + if l1 > 15 { + l1 -= 16; + qh[j] |= m1; + } + let mut l2 = l[n + j + 32]; + if l2 > 15 { + l2 -= 16; + qh[j] |= m2; + } + ql[offset + j] = l1 | (l2 << 4); + } + m1 <<= 2; + m2 <<= 2; + } } - let mut ys_index = 0; - for x in xs.iter() { - let d = x.d.to_f32(); - let min = x.dmin.to_f32(); - let ql = &x.qs; - let qh = &x.qh; + + Ok(()) + } + + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + for (block, y) in group_for_dequantization(xs, ys)? { + let d = block.d.to_f32(); + let min = block.dmin.to_f32(); + let ql = &block.qs; + let qh = &block.qh; let mut is = 0; let mut u1 = 1; let mut u2 = 2; + let mut ys_index = 0; + for j in (0..QK_K).step_by(64) { let ql = &ql[j / 2..j / 2 + 32]; - let (sc, m) = get_scale_min_k4(is, &x.scales); + let (sc, m) = get_scale_min_k4(is, &block.scales); let d1 = d * sc as f32; let m1 = min * m as f32; - let (sc, m) = get_scale_min_k4(is + 1, &x.scales); + let (sc, m) = get_scale_min_k4(is + 1, &block.scales); let d2 = d * sc as f32; let m2 = min * m as f32; for (ql, qh) in ql.iter().zip(qh) { let to_add = if qh & u1 != 0 { 16 } else { 1 }; - let y = d1 * ((ql & 0xF) + to_add) as f32 - m1; - ys[ys_index] = y; + y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1; ys_index += 1; } for (ql, qh) in ql.iter().zip(qh) { let to_add = if qh & u2 != 0 { 16 } else { 1 }; - let y = d2 * ((ql >> 4) + to_add) as f32 - m2; - ys[ys_index] = y; + y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2; ys_index += 1; } is += 2; @@ -458,143 +947,6 @@ impl GgmlType for BlockQ5K { } } -fn nearest_int(v: f32) -> i32 { - v.round() as i32 -} - -unsafe fn make_qx_quants(n: usize, nmax: i32, x: *const f32, ls: *mut i8, rmse_type: i32) -> f32 { - let mut max = 0f32; - let mut amax = 0f32; - for i in 0..n { - let x = *x.add(i); - let ax = x.abs(); - if ax > amax { - amax = ax; - max = x; - } - } - if amax == 0. { - // all zero - for i in 0..n { - *ls.add(i) = 0; - } - return 0.; - } - let mut iscale = -(nmax as f32) / max; - if rmse_type == 0 { - for i in 0..n { - let x = *x.add(i); - let l = nearest_int(iscale * x); - *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8; - } - return 1.0 / iscale; - } - let weight_type = rmse_type % 2; - let mut sumlx = 0f32; - let mut suml2 = 0f32; - for i in 0..n { - let x = *x.add(i); - let l = nearest_int(iscale * x); - let l = l.clamp(-nmax, nmax - 1); - *ls.add(i) = (l + nmax) as i8; - let w = if weight_type == 1 { x * x } else { 1.0 }; - let l = l as f32; - sumlx += w * x * l; - suml2 += w * l * l; - } - let mut scale = sumlx / suml2; - let mut best = scale * sumlx; - for _itry in 0..3 { - let iscale = 1.0 / scale; - let mut slx = 0f32; - let mut sl2 = 0f32; - let mut changed = false; - for i in 0..n { - let x = *x.add(i); - let l = nearest_int(iscale * x); - let l = l.clamp(-nmax, nmax - 1); - if l + nmax != *ls.add(i) as i32 { - changed = true; - } - let w = if weight_type == 1 { x * x } else { 1f32 }; - let l = l as f32; - slx += w * x * l; - sl2 += w * l * l; - } - if !changed || sl2 == 0.0 || slx * slx <= best * sl2 { - break; - } - for i in 0..n { - let x = *x.add(i); - let l = nearest_int(iscale * x); - *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8; - } - sumlx = slx; - suml2 = sl2; - scale = sumlx / suml2; - best = scale * sumlx; - } - for _itry in 0..5 { - let mut n_changed = 0; - for i in 0..n { - let x = *x.add(i); - let w = if weight_type == 1 { x * x } else { 1. }; - let l = *ls.add(i) as i32 - nmax; - let mut slx = sumlx - w * x * l as f32; - if slx > 0. { - let mut sl2 = suml2 - w * l as f32 * l as f32; - let new_l = nearest_int(x * sl2 / slx); - let new_l = new_l.clamp(-nmax, nmax - 1); - if new_l != l { - slx += w * x * new_l as f32; - sl2 += w * new_l as f32 * new_l as f32; - if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 { - *ls.add(i) = (nmax + new_l) as i8; - sumlx = slx; - suml2 = sl2; - scale = sumlx / suml2; - best = scale * sumlx; - n_changed += 1; - } - } - } - } - if n_changed == 0 { - break; - } - } - if rmse_type < 3 { - return scale; - } - for is in -4..4 { - if is == 0 { - continue; - } - iscale = -(nmax as f32 + 0.1f32 * is as f32) / max; - let mut sumlx = 0.; - let mut suml2 = 0.; - for i in 0..n { - let x = *x.add(i); - let l = nearest_int(iscale * x); - let l = i32::max(-nmax, i32::min(nmax - 1, l)); - let w = if weight_type == 1 { x * x } else { 1. }; - let l = l as f32; - sumlx += w * x * l; - suml2 += w * l * l; - } - if suml2 > 0. && sumlx * sumlx > best * suml2 { - for i in 0..n { - let x = *x.add(i); - let l = nearest_int(iscale * x); - *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8; - } - scale = sumlx / suml2; - best = scale * sumlx; - } - } - scale -} - impl GgmlType for BlockQ6K { const DTYPE: GgmlDType = GgmlDType::Q6K; const BLCK_SIZE: usize = QK_K; @@ -833,183 +1185,6 @@ impl GgmlType for BlockQ8K { } } -impl GgmlType for BlockQ4_0 { - const DTYPE: GgmlDType = GgmlDType::Q4_0; - const BLCK_SIZE: usize = QK4_0; - type VecDotType = BlockQ8_0; - - // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525 - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { - let k = ys.len(); - let qk = Self::BLCK_SIZE; - if k % qk != 0 { - crate::bail!("dequantize_row_q4_0: {k} is not divisible by {qk}") - } - - let nb = k / qk; - for i in 0..nb { - let d = xs[i].d.to_f32(); - - for j in 0..(qk / 2) { - let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8; - let x1 = (xs[i].qs[j] >> 4) as i16 - 8; - - ys[i * qk + j] = (x0 as f32) * d; - ys[i * qk + j + qk / 2] = (x1 as f32) * d; - } - } - Ok(()) - } - - fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { - // quantize_row_q4_0 - let qk = Self::BLCK_SIZE; - let k = xs.len(); - if k % qk != 0 { - crate::bail!("{k} is not divisible by {}", qk); - }; - let nb = k / qk; - if ys.len() != nb { - crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,) - } - for (i, ys) in ys.iter_mut().enumerate() { - let mut amax = 0f32; - let mut max = 0f32; - - let xs = &xs[i * qk..(i + 1) * qk]; - for &x in xs.iter() { - if amax < x.abs() { - amax = x.abs(); - max = x; - } - } - let d = max / -8.0; - let id = if d != 0f32 { 1. / d } else { 0. }; - ys.d = f16::from_f32(d); - - for (j, q) in ys.qs.iter_mut().enumerate() { - let x0 = xs[j] * id; - let x1 = xs[qk / 2 + j] * id; - let xi0 = u8::min(15, (x0 + 8.5) as u8); - let xi1 = u8::min(15, (x1 + 8.5) as u8); - *q = xi0 | (xi1 << 4) - } - } - Ok(()) - } - - // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122 - #[allow(unreachable_code)] - fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - #[cfg(target_feature = "avx")] - return super::avx::vec_dot_q4_0_q8_0(n, xs, ys); - - #[cfg(target_feature = "neon")] - return super::neon::vec_dot_q4_0_q8_0(n, xs, ys); - - let qk = QK8_0; - let nb = n / qk; - if n % QK8_0 != 0 { - crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") - } - if nb % 2 != 0 { - crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") - } - - // Generic implementation. - let mut sumf = 0f32; - for (xs, ys) in xs.iter().zip(ys.iter()) { - let mut sum_i = 0; - for j in 0..qk / 2 { - let v0 = (xs.qs[j] & 0x0F) as i32 - 8; - let v1 = (xs.qs[j] >> 4) as i32 - 8; - sum_i += v0 * ys.qs[j] as i32 + v1 * ys.qs[j + qk / 2] as i32 - } - sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) - } - Ok(sumf) - } -} - -impl GgmlType for BlockQ8_0 { - const DTYPE: GgmlDType = GgmlDType::Q8_0; - const BLCK_SIZE: usize = QK8_0; - type VecDotType = BlockQ8_0; - - // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619 - fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { - let k = ys.len(); - if k % QK8_0 != 0 { - crate::bail!("dequantize_row_q8_0: {k} is not divisible by {QK8_0}"); - } - - let nb = k / QK8_0; - - for i in 0..nb { - let d = xs[i].d.to_f32(); - - for j in 0..QK8_0 { - ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d; - } - } - Ok(()) - } - - fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { - // quantize_row_q8_0 - let k = xs.len(); - if k % Self::BLCK_SIZE != 0 { - crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE); - }; - let nb = k / Self::BLCK_SIZE; - if ys.len() != nb { - crate::bail!( - "size mismatch {} {} {}", - xs.len(), - ys.len(), - Self::BLCK_SIZE - ) - } - for (i, ys) in ys.iter_mut().enumerate() { - let mut amax = 0f32; - let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; - for &x in xs.iter() { - amax = amax.max(x.abs()) - } - let d = amax / ((1 << 7) - 1) as f32; - let id = if d != 0f32 { 1. / d } else { 0. }; - ys.d = f16::from_f32(d); - for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) { - *y = f32::round(x * id) as i8 - } - } - Ok(()) - } - - fn vec_dot(_: usize, _: &[Self], _: &[Self::VecDotType]) -> Result { - todo!() - } -} - -impl GgmlType for BlockQ8_1 { - const DTYPE: GgmlDType = GgmlDType::Q3K; - const BLCK_SIZE: usize = QK_K; - type VecDotType = BlockQ8_1; - - fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result { - todo!() - } - - fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> { - todo!() - } - - // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533 - fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> { - todo!() - } -} - // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 pub fn matmul( mkn: (usize, usize, usize), diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 41661633..f2c78689 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -6,6 +6,7 @@ pub mod ggml_file; pub mod k_quants; #[cfg(target_feature = "neon")] pub mod neon; +pub mod utils; pub use k_quants::GgmlType; diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/utils.rs new file mode 100644 index 00000000..fded9d61 --- /dev/null +++ b/candle-core/src/quantized/utils.rs @@ -0,0 +1,326 @@ +use crate::Result; + +pub(super) fn nearest_int(v: f32) -> i32 { + v.round() as i32 +} + +/// Validates that the input and output are the right size and returns an iterator which maps each input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed to be `T::BLCK_SIZE` long. +pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>( + xs: &'b [f32], + ys: &'a mut [T], +) -> Result> { + let block_size = T::BLCK_SIZE; + let dtype = T::DTYPE; + + let expected_blocks = xs.len() / block_size; + let actual_blocks = ys.len(); + + //validate that the input is the right size + if expected_blocks != actual_blocks { + crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!") + } + + Ok(ys.iter_mut().zip(xs.chunks_exact(block_size)).collect()) +} + +/// Validates that the input and output are the right size and returns an iterator which maps each input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed to be `T::BLCK_SIZE` long. +pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>( + xs: &'a [T], + ys: &'b mut [f32], +) -> Result> { + let block_size = T::BLCK_SIZE; + let dtype = T::DTYPE; + + let actual_output_len = ys.len(); + let expected_output_len = xs.len() * block_size; + //validate that the output is the right size + if expected_output_len != actual_output_len { + crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!") + } + + //zip the blocks and outputs together + Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect()) +} + +pub(super) fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) { + if j < 4 { + let d = q[j] & 63; + let m = q[j + 4] & 63; + (d, m) + } else { + let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4); + (d, m) + } +} + +pub(super) unsafe fn make_qx_quants( + n: usize, + nmax: i32, + x: *const f32, + ls: *mut i8, + rmse_type: i32, +) -> f32 { + let mut max = 0f32; + let mut amax = 0f32; + for i in 0..n { + let x = *x.add(i); + let ax = x.abs(); + if ax > amax { + amax = ax; + max = x; + } + } + if amax == 0. { + // all zero + for i in 0..n { + *ls.add(i) = 0; + } + return 0.; + } + let mut iscale = -(nmax as f32) / max; + if rmse_type == 0 { + for i in 0..n { + let x = *x.add(i); + let l = nearest_int(iscale * x); + *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8; + } + return 1.0 / iscale; + } + let weight_type = rmse_type % 2; + let mut sumlx = 0f32; + let mut suml2 = 0f32; + for i in 0..n { + let x = *x.add(i); + let l = nearest_int(iscale * x); + let l = l.clamp(-nmax, nmax - 1); + *ls.add(i) = (l + nmax) as i8; + let w = if weight_type == 1 { x * x } else { 1.0 }; + let l = l as f32; + sumlx += w * x * l; + suml2 += w * l * l; + } + let mut scale = sumlx / suml2; + let mut best = scale * sumlx; + for _itry in 0..3 { + let iscale = 1.0 / scale; + let mut slx = 0f32; + let mut sl2 = 0f32; + let mut changed = false; + for i in 0..n { + let x = *x.add(i); + let l = nearest_int(iscale * x); + let l = l.clamp(-nmax, nmax - 1); + if l + nmax != *ls.add(i) as i32 { + changed = true; + } + let w = if weight_type == 1 { x * x } else { 1f32 }; + let l = l as f32; + slx += w * x * l; + sl2 += w * l * l; + } + if !changed || sl2 == 0.0 || slx * slx <= best * sl2 { + break; + } + for i in 0..n { + let x = *x.add(i); + let l = nearest_int(iscale * x); + *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8; + } + sumlx = slx; + suml2 = sl2; + scale = sumlx / suml2; + best = scale * sumlx; + } + for _itry in 0..5 { + let mut n_changed = 0; + for i in 0..n { + let x = *x.add(i); + let w = if weight_type == 1 { x * x } else { 1. }; + let l = *ls.add(i) as i32 - nmax; + let mut slx = sumlx - w * x * l as f32; + if slx > 0. { + let mut sl2 = suml2 - w * l as f32 * l as f32; + let new_l = nearest_int(x * sl2 / slx); + let new_l = new_l.clamp(-nmax, nmax - 1); + if new_l != l { + slx += w * x * new_l as f32; + sl2 += w * new_l as f32 * new_l as f32; + if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 { + *ls.add(i) = (nmax + new_l) as i8; + sumlx = slx; + suml2 = sl2; + scale = sumlx / suml2; + best = scale * sumlx; + n_changed += 1; + } + } + } + } + if n_changed == 0 { + break; + } + } + if rmse_type < 3 { + return scale; + } + for is in -4..4 { + if is == 0 { + continue; + } + iscale = -(nmax as f32 + 0.1f32 * is as f32) / max; + let mut sumlx = 0.; + let mut suml2 = 0.; + for i in 0..n { + let x = *x.add(i); + let l = nearest_int(iscale * x); + let l = i32::max(-nmax, i32::min(nmax - 1, l)); + let w = if weight_type == 1 { x * x } else { 1. }; + let l = l as f32; + sumlx += w * x * l; + suml2 += w * l * l; + } + if suml2 > 0. && sumlx * sumlx > best * suml2 { + for i in 0..n { + let x = *x.add(i); + let l = nearest_int(iscale * x); + *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8; + } + scale = sumlx / suml2; + best = scale * sumlx; + } + } + scale +} + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L224 +pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) { + let n = x.len(); + let mut l = vec![0; n]; + // Get min/max + let mut min = *x + .iter() + .take(n) + .min_by(|a, b| a.total_cmp(b)) + .unwrap_or(&x[0]); + let max = *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&x[0]); + + // If min == max, all values are the same => nothing to do here + if max == min { + return (0.0, 0.0); + } + + // Ensure min <= 0.0 + if min > 0.0 { + min = 0.0; + } + + // Compute scale and inverse scale + let mut iscale = nmax as f32 / (max - min); + let mut scale = 1.0 / iscale; + + for _ in 0..ntry { + let mut sumlx = 0.0; + let mut suml2 = 0; + let mut did_change = false; + + for (i, value) in x.iter().enumerate().take(n) { + let mut li = nearest_int(iscale * (value - min)); + li = li.clamp(0, nmax); + let clamped_li = li as u8; + if clamped_li != l[i] { + l[i] = clamped_li; + did_change = true; + } + sumlx += (value - min) * li as f32; + suml2 += li * li; + } + scale = sumlx / suml2 as f32; + + let sum: f32 = x + .iter() + .take(n) + .zip(l.iter().take(n)) + .map(|(xi, &li)| xi - scale * li as f32) + .sum(); + + min = sum / n as f32; + if min > 0.0 { + min = 0.0; + } + iscale = 1.0 / scale; + if !did_change { + break; + } + } + (scale, -min) +} + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L165 +pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 { + let n = x.len(); + let mut l = vec![0i8; n]; + + let mut max = 0.0; + let mut amax = 0.0; + for &xi in x.iter().take(n) { + let ax = xi.abs(); + if ax > amax { + amax = ax; + max = xi; + } + } + + if amax == 0.0 { + return 0.0; + } + + let iscale = -(nmax as f32) / max; + if do_rmse { + let mut sumlx = 0.0; + let mut suml2 = 0.0; + for i in 0..n { + let mut li = (iscale * x[i]).round() as i32; + li = li.clamp(-nmax, nmax - 1); + l[i] = li as i8; + let w = x[i] * x[i]; + sumlx += w * x[i] * li as f32; + suml2 += w * (li * li) as f32; + } + for _ in 0..5 { + let mut n_changed = 0; + for i in 0..n { + let w = x[i] * x[i]; + let mut slx = sumlx - w * x[i] * l[i] as f32; + if slx > 0.0 { + let mut sl2 = suml2 - w * (l[i] as i32 * l[i] as i32) as f32; + let mut new_l = (x[i] * sl2 / slx).round() as i32; + new_l = new_l.clamp(-nmax, nmax - 1); + if new_l != l[i] as i32 { + slx += w * x[i] * new_l as f32; + sl2 += w * (new_l * new_l) as f32; + if sl2 > 0.0 && slx * slx * suml2 > sumlx * sumlx * sl2 { + l[i] = new_l as i8; + sumlx = slx; + suml2 = sl2; + n_changed += 1; + } + } + } + } + if n_changed == 0 { + break; + } + } + for li in l.iter_mut() { + *li += nmax as i8; + } + return sumlx / suml2; + } + for i in 0..n { + let mut li = (iscale * x[i]).round() as i32; + li = li.clamp(-nmax, nmax - 1); + l[i] = (li + nmax) as i8; + } + 1.0 / iscale +} diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index ffb79c5b..47a2a25e 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -124,25 +124,152 @@ fn quantize_q4_0() -> Result<()> { Ok(()) } -#[test] -fn quantize_q8k() -> Result<()> { - use k_quants::BlockQ8K; +/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps +fn get_test_vector(bound: f32, size: Option) -> (Vec, Vec) { + let size = size.unwrap_or(1024); + assert!( + size % crate::quantized::k_quants::QK_K == 0, + "size must be a multiple of {}", + crate::quantized::k_quants::QK_K + ); - let src = (0..256 * 4) - .map(|v| (v as f32 - 512.) / 1024.) + let src = (0..size) + .map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.)) .collect::>(); - let mut dst = vec![0f32; 256 * 4]; - let mut quant = vec![BlockQ8K::zeros(); 4]; - BlockQ8K::from_float(&src, &mut quant)?; - BlockQ8K::to_float(&quant, dst.as_mut_slice())?; + + let dst = vec![0f32; size]; + assert_eq!([src[0], src[size / 2]], [-bound, 0.0]); + (src, dst) +} + +/// Round a vector +fn round_vector(values: &[f32]) -> Vec { + values + .iter() + .map(|x| (1000. * x).round() / 1000.) + .collect::>() +} + +fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) { + for (i, (value, expected_value)) in values.iter().zip(expected.iter()).enumerate() { + let difference = (value - expected_value).abs(); + + assert!( + difference < tolerance, + "Error at index {}: value = {}, expected = {}. Difference = {} exceeds tolerance = {}.", + i, + value, + expected_value, + difference, + tolerance + ); + } +} + +fn quantize_roundtrip(src: &[f32], dst: &mut [f32]) -> Result> { + let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE]; + T::from_float(src, &mut quant)?; + T::to_float(&quant, dst)?; + Ok(quant) +} + +#[test] +fn quantize_q2k() -> Result<()> { + use k_quants::BlockQ2K; + + let (src, mut dst) = get_test_vector(0.5, Some(1024)); + let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.1); + + // Test some specific values assert_eq!( [src[0], src[128], src[256], src[512], src[800], src[1023]], [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] ); + let dst = round_vector(&dst); assert_eq!( [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], - [-0.5, -0.375, -0.25, -0.0, 0.28070068, 0.49902344] + [-0.499, -0.366, -0.249, 0.0, 0.295, 0.492] ); + + let (src_big, mut dst_big) = get_test_vector(128.0, Some(1024)); + let _quant_big = quantize_roundtrip::(src_big.as_slice(), dst_big.as_mut_slice())?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0); + Ok(()) +} + +#[test] +fn quantize_q3k() -> Result<()> { + use k_quants::BlockQ3K; + + let (src, mut dst) = get_test_vector(0.5, Some(1024)); + let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.03); + + // Test some specific values + assert_eq!( + [src[0], src[128], src[256], src[512], src[800], src[1023]], + [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] + ); + let dst = round_vector(&dst); + assert_eq!( + [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], + [-0.493, -0.37, -0.243, -0.0, 0.292, 0.492] + ); + + let (src_big, mut dst_big) = get_test_vector(128.0, Some(1024)); + let _quant_big = quantize_roundtrip::(src_big.as_slice(), dst_big.as_mut_slice())?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5); + Ok(()) +} + +#[test] +fn quantize_q4k() -> Result<()> { + use k_quants::BlockQ4K; + + let (src, mut dst) = get_test_vector(0.5, Some(1024)); + let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.017); + + // Test some specific values + assert_eq!( + [src[0], src[128], src[256], src[512], src[800], src[1023]], + [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] + ); + let dst = round_vector(&dst); + assert_eq!( + [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], + [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498] + ); + + let (src_big, mut dst_big) = get_test_vector(128.0, Some(1024)); + let _quant_big = quantize_roundtrip::(src_big.as_slice(), dst_big.as_mut_slice())?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5); + Ok(()) +} + +#[test] +fn quantize_q5k() -> Result<()> { + use k_quants::BlockQ5K; + + let (src, mut dst) = get_test_vector(0.5, Some(1024)); + let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.008); + + // Test some specific values + assert_eq!( + [src[0], src[128], src[256], src[512], src[800], src[1023]], + [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] + ); + let dst = round_vector(&dst); + assert_eq!( + [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], + [-0.499, -0.372, -0.249, 0.001, 0.279, 0.499] + ); + + let (src_big, mut dst_big) = get_test_vector(128.0, Some(1024)); + let _quant_big = quantize_roundtrip::(src_big.as_slice(), dst_big.as_mut_slice())?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5); Ok(()) } @@ -150,25 +277,51 @@ fn quantize_q8k() -> Result<()> { fn quantize_q6k() -> Result<()> { use k_quants::BlockQ6K; - let src = (0..256 * 4) - .map(|v| (v as f32 - 512.) / 1024.) - .collect::>(); - let mut dst = vec![0f32; 256 * 4]; - let mut quant = vec![BlockQ6K::zeros(); 4]; - BlockQ6K::from_float(&src, &mut quant)?; - BlockQ6K::to_float(&quant, dst.as_mut_slice())?; + let (src, mut dst) = get_test_vector(0.5, Some(1024)); + let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.008); + + // Test some specific values assert_eq!( [src[0], src[128], src[256], src[512], src[800], src[1023]], [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] ); - let dst = dst - .iter() - .map(|x| (1000. * x).round() / 1000.) - .collect::>(); + let dst = round_vector(&dst); assert_eq!( [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], [-0.497, -0.372, -0.25, -0.0, 0.284, 0.5] ); + + let (src_big, mut dst_big) = get_test_vector(128.0, Some(1024)); + let _quant_big = quantize_roundtrip::(src_big.as_slice(), dst_big.as_mut_slice())?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0); + + Ok(()) +} + +#[test] +fn quantize_q8k() -> Result<()> { + use k_quants::BlockQ8K; + + let (src, mut dst) = get_test_vector(0.5, Some(1024)); + let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.003); + + // Test some specific values + assert_eq!( + [src[0], src[128], src[256], src[512], src[800], src[1023]], + [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] + ); + let dst = round_vector(&dst); + assert_eq!( + [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], + [-0.5, -0.375, -0.25, -0.0, 0.281, 0.499] + ); + + let (src_big, mut dst_big) = get_test_vector(128.0, Some(1024)); + let _quant_big = quantize_roundtrip::(src_big.as_slice(), dst_big.as_mut_slice())?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6); + Ok(()) }