Q6K quantization (#495)

* Print the detected arch options.

* Add the q6k quantization.

* Add a currently broken test.

* Bugfix.

* Bugfix.

* Another bugfix.

* Another bugfix + get the test to work.
This commit is contained in:
Laurent Mazare
2023-08-17 22:22:57 +01:00
committed by GitHub
parent fc81af1712
commit 557b2c28dd
4 changed files with 257 additions and 2 deletions

View File

@ -457,6 +457,143 @@ impl GgmlType for BlockQ5K {
} }
} }
fn nearest_int(v: f32) -> i32 {
v.round() as i32
}
unsafe fn make_qx_quants(n: usize, nmax: i32, x: *const f32, ls: *mut i8, rmse_type: i32) -> f32 {
let mut max = 0f32;
let mut amax = 0f32;
for i in 0..n {
let x = *x.add(i);
let ax = x.abs();
if ax > amax {
amax = ax;
max = x;
}
}
if amax == 0. {
// all zero
for i in 0..n {
*ls.add(i) = 0;
}
return 0.;
}
let mut iscale = -(nmax as f32) / max;
if rmse_type == 0 {
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
return 1.0 / iscale;
}
let weight_type = rmse_type % 2;
let mut sumlx = 0f32;
let mut suml2 = 0f32;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
*ls.add(i) = (l + nmax) as i8;
let w = if weight_type == 1 { x * x } else { 1.0 };
let l = l as f32;
sumlx += w * x * l;
suml2 += w * l * l;
}
let mut scale = sumlx / suml2;
let mut best = scale * sumlx;
for _itry in 0..3 {
let iscale = 1.0 / scale;
let mut slx = 0f32;
let mut sl2 = 0f32;
let mut changed = false;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
if l + nmax != *ls.add(i) as i32 {
changed = true;
}
let w = if weight_type == 1 { x * x } else { 1f32 };
let l = l as f32;
slx += w * x * l;
sl2 += w * l * l;
}
if !changed || sl2 == 0.0 || slx * slx <= best * sl2 {
break;
}
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
sumlx = slx;
suml2 = sl2;
scale = sumlx / suml2;
best = scale * sumlx;
}
for _itry in 0..5 {
let mut n_changed = 0;
for i in 0..n {
let x = *x.add(i);
let w = if weight_type == 1 { x * x } else { 1. };
let l = *ls.add(i) as i32 - nmax;
let mut slx = sumlx - w * x * l as f32;
if slx > 0. {
let mut sl2 = suml2 - w * l as f32 * l as f32;
let new_l = nearest_int(x * sl2 / slx);
let new_l = new_l.clamp(-nmax, nmax - 1);
if new_l != l {
slx += w * x * new_l as f32;
sl2 += w * new_l as f32 * new_l as f32;
if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 {
*ls.add(i) = (nmax + new_l) as i8;
sumlx = slx;
suml2 = sl2;
scale = sumlx / suml2;
best = scale * sumlx;
n_changed += 1;
}
}
}
}
if n_changed == 0 {
break;
}
}
if rmse_type < 3 {
return scale;
}
for is in -4..4 {
if is == 0 {
continue;
}
iscale = -(nmax as f32 + 0.1f32 * is as f32) / max;
let mut sumlx = 0.;
let mut suml2 = 0.;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = i32::max(-nmax, i32::min(nmax - 1, l));
let w = if weight_type == 1 { x * x } else { 1. };
let l = l as f32;
sumlx += w * x * l;
suml2 += w * l * l;
}
if suml2 > 0. && sumlx * sumlx > best * suml2 {
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
scale = sumlx / suml2;
best = scale * sumlx;
}
}
scale
}
impl GgmlType for BlockQ6K { impl GgmlType for BlockQ6K {
const DTYPE: GgmlDType = GgmlDType::Q6K; const DTYPE: GgmlDType = GgmlDType::Q6K;
const BLCK_SIZE: usize = QK_K; const BLCK_SIZE: usize = QK_K;
@ -524,8 +661,76 @@ impl GgmlType for BlockQ6K {
Ok(sums.iter().sum()) Ok(sums.iter().sum())
} }
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> { fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
todo!() if xs.len() != ys.len() * Self::BLCK_SIZE {
crate::bail!(
"quantize_row_q6k: size mismatch {} {} {}",
xs.len(),
ys.len(),
Self::BLCK_SIZE
)
}
let mut l = [0i8; QK_K];
let mut scales = [0f32; QK_K / 16];
let mut x = xs.as_ptr();
let l = l.as_mut_ptr();
unsafe {
for y in ys.iter_mut() {
let mut max_scale = 0f32;
let mut max_abs_scale = 0f32;
for (ib, scale_) in scales.iter_mut().enumerate() {
let scale = make_qx_quants(16, 32, x.add(16 * ib), l.add(16 * ib), 1);
*scale_ = scale;
let abs_scale = scale.abs();
if abs_scale > max_abs_scale {
max_abs_scale = abs_scale;
max_scale = scale
}
}
let iscale = -128f32 / max_scale;
y.d = f16::from_f32(1.0 / iscale);
for (y_scale, scale) in y.scales.iter_mut().zip(scales.iter()) {
*y_scale = nearest_int(iscale * scale).min(127) as i8
}
for (j, &y_scale) in y.scales.iter().enumerate() {
let d = y.d.to_f32() * y_scale as f32;
if d == 0. {
continue;
}
for ii in 0..16 {
let ll = nearest_int(*x.add(16 * j + ii) / d).clamp(-32, 31);
*l.add(16 * j + ii) = (ll + 32) as i8
}
}
let mut ql = y.ql.as_mut_ptr();
let mut qh = y.qh.as_mut_ptr();
for j in (0..QK_K).step_by(128) {
for l_idx in 0..32 {
let q1 = *l.add(j + l_idx) & 0xF;
let q2 = *l.add(j + l_idx + 32) & 0xF;
let q3 = *l.add(j + l_idx + 64) & 0xF;
let q4 = *l.add(j + l_idx + 96) & 0xF;
*ql.add(l_idx) = (q1 | (q3 << 4)) as u8;
*ql.add(l_idx + 32) = (q2 | (q4 << 4)) as u8;
*qh.add(l_idx) = ((*l.add(j + l_idx) >> 4)
| ((*l.add(j + l_idx + 32) >> 4) << 2)
| ((*l.add(j + l_idx + 64) >> 4) << 4)
| ((*l.add(j + l_idx + 96) >> 4) << 6))
as u8;
}
ql = ql.add(64);
qh = qh.add(32);
}
x = x.add(QK_K)
}
}
Ok(())
} }
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067 // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067

View File

@ -22,3 +22,19 @@ pub fn has_mkl() -> bool {
pub fn cuda_is_available() -> bool { pub fn cuda_is_available() -> bool {
cfg!(feature = "cuda") cfg!(feature = "cuda")
} }
pub fn with_avx() -> bool {
cfg!(target_feature = "avx")
}
pub fn with_neon() -> bool {
cfg!(target_feature = "neon")
}
pub fn with_simd128() -> bool {
cfg!(target_feature = "simd128")
}
pub fn with_f16c() -> bool {
cfg!(target_feature = "f16c")
}

View File

@ -145,3 +145,29 @@ fn quantize_q8k() -> Result<()> {
); );
Ok(()) Ok(())
} }
#[test]
fn quantize_q6k() -> Result<()> {
use k_quants::BlockQ6K;
let src = (0..256 * 4)
.map(|v| (v as f32 - 512.) / 1024.)
.collect::<Vec<_>>();
let mut dst = vec![0f32; 256 * 4];
let mut quant = vec![BlockQ6K::zeros(); 4];
BlockQ6K::from_float(&src, &mut quant)?;
BlockQ6K::to_float(&quant, dst.as_mut_slice())?;
assert_eq!(
[src[0], src[128], src[256], src[512], src[800], src[1023]],
[-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
);
let dst = dst
.iter()
.map(|x| (1000. * x).round() / 1000.)
.collect::<Vec<_>>();
assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
[-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
);
Ok(())
}

View File

@ -348,6 +348,14 @@ fn main() -> anyhow::Result<()> {
None None
}; };
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let mut file = std::fs::File::open(&args.model()?)?; let mut file = std::fs::File::open(&args.model()?)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let model = Content::read(&mut file)?; let model = Content::read(&mut file)?;