mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
More missing quantized bits. (#615)
* Q4_1 support. * Add Q5_1 quantization. * Tweak.
This commit is contained in:
@ -254,12 +254,65 @@ impl GgmlType for BlockQ4_1 {
|
||||
const BLCK_SIZE: usize = QK4_1;
|
||||
type VecDotType = BlockQ8_1;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
// ggml_vec_dot_q4_1_q8_1
|
||||
let qk = QK8_1;
|
||||
if n % qk != 0 {
|
||||
crate::bail!("vec_dot_q4_1_q8_1: {n} is not divisible by {qk}")
|
||||
}
|
||||
let nb = n / qk;
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q4_1_q8_1: {n}, nb is not divisible by 2")
|
||||
}
|
||||
|
||||
// Generic implementation.
|
||||
let mut sumf = 0f32;
|
||||
|
||||
for (xs, ys) in xs.iter().zip(ys.iter()) {
|
||||
let mut sumi = 0i32;
|
||||
|
||||
for j in 0..qk / 2 {
|
||||
let v0 = xs.qs[j] as i32 & 0x0F;
|
||||
let v1 = xs.qs[j] as i32 >> 4;
|
||||
sumi += (v0 * ys.qs[j] as i32) + (v1 * ys.qs[j + qk / 2] as i32);
|
||||
}
|
||||
|
||||
sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
|
||||
}
|
||||
Ok(sumf)
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
// quantize_row_q4_1
|
||||
let qk = Self::BLCK_SIZE;
|
||||
if ys.len() * qk != xs.len() {
|
||||
crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,)
|
||||
}
|
||||
for (i, ys) in ys.iter_mut().enumerate() {
|
||||
let xs = &xs[i * qk..(i + 1) * qk];
|
||||
|
||||
let mut min = f32::INFINITY;
|
||||
let mut max = f32::NEG_INFINITY;
|
||||
for &x in xs.iter() {
|
||||
min = f32::min(x, min);
|
||||
max = f32::max(x, max);
|
||||
}
|
||||
let d = (max - min) / ((1 << 4) - 1) as f32;
|
||||
let id = if d != 0f32 { 1. / d } else { 0. };
|
||||
ys.d = f16::from_f32(d);
|
||||
ys.m = f16::from_f32(min);
|
||||
|
||||
for (j, q) in ys.qs.iter_mut().enumerate() {
|
||||
let x0 = (xs[i * qk + j] - min) * id;
|
||||
let x1 = (xs[i * qk + qk / 2 + j] - min) * id;
|
||||
|
||||
let xi0 = u8::min(15, (x0 + 0.5) as u8);
|
||||
let xi1 = u8::min(15, (x1 + 0.5) as u8);
|
||||
|
||||
*q = xi0 | (xi1 << 4);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1545
|
||||
@ -422,8 +475,42 @@ impl GgmlType for BlockQ5_1 {
|
||||
Ok(sumf)
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
// quantize_row_q5_1
|
||||
let qk = Self::BLCK_SIZE;
|
||||
if ys.len() * qk != xs.len() {
|
||||
crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,)
|
||||
}
|
||||
for (i, ys) in ys.iter_mut().enumerate() {
|
||||
let xs = &xs[i * qk..(i + 1) * qk];
|
||||
|
||||
let mut min = f32::INFINITY;
|
||||
let mut max = f32::NEG_INFINITY;
|
||||
for &x in xs.iter() {
|
||||
min = f32::min(x, min);
|
||||
max = f32::max(x, max);
|
||||
}
|
||||
let d = (max - min) / ((1 << 5) - 1) as f32;
|
||||
let id = if d != 0f32 { 1. / d } else { 0. };
|
||||
ys.d = f16::from_f32(d);
|
||||
ys.m = f16::from_f32(min);
|
||||
|
||||
let mut qh = 0u32;
|
||||
for (j, q) in ys.qs.iter_mut().enumerate() {
|
||||
let x0 = (xs[i * qk + j] - min) * id;
|
||||
let x1 = (xs[i * qk + qk / 2 + j] - min) * id;
|
||||
|
||||
let xi0 = (x0 + 0.5) as u8;
|
||||
let xi1 = (x1 + 0.5) as u8;
|
||||
|
||||
*q = (xi0 & 0x0F) | ((xi1 & 0x0F0) << 4);
|
||||
// get the 5-th bit and store it in qh at the right position
|
||||
qh |= ((xi0 as u32 & 0x10) >> 4) << j;
|
||||
qh |= ((xi1 as u32 & 0x10) >> 4) << (j + qk / 2);
|
||||
}
|
||||
LittleEndian::write_u32(&mut ys.qh, qh);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1592
|
||||
@ -767,7 +854,7 @@ impl GgmlType for BlockQ3K {
|
||||
let hmask: &[u8] = &x.hmask;
|
||||
let mut q8: &[i8] = &y.qs;
|
||||
|
||||
aux32.iter_mut().for_each(|x| *x = 0);
|
||||
aux32.fill(0);
|
||||
let mut a = &mut aux8[..];
|
||||
|
||||
let mut m = 1;
|
||||
|
Reference in New Issue
Block a user