Q6K quantization (#495)

* Print the detected arch options.

* Add the q6k quantization.

* Add a currently broken test.

* Bugfix.

* Bugfix.

* Another bugfix.

* Another bugfix + get the test to work.
This commit is contained in:
Laurent Mazare
2023-08-17 22:22:57 +01:00
committed by GitHub
parent fc81af1712
commit 557b2c28dd
4 changed files with 257 additions and 2 deletions

View File

@ -457,6 +457,143 @@ impl GgmlType for BlockQ5K {
}
}
fn nearest_int(v: f32) -> i32 {
v.round() as i32
}
unsafe fn make_qx_quants(n: usize, nmax: i32, x: *const f32, ls: *mut i8, rmse_type: i32) -> f32 {
let mut max = 0f32;
let mut amax = 0f32;
for i in 0..n {
let x = *x.add(i);
let ax = x.abs();
if ax > amax {
amax = ax;
max = x;
}
}
if amax == 0. {
// all zero
for i in 0..n {
*ls.add(i) = 0;
}
return 0.;
}
let mut iscale = -(nmax as f32) / max;
if rmse_type == 0 {
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
return 1.0 / iscale;
}
let weight_type = rmse_type % 2;
let mut sumlx = 0f32;
let mut suml2 = 0f32;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
*ls.add(i) = (l + nmax) as i8;
let w = if weight_type == 1 { x * x } else { 1.0 };
let l = l as f32;
sumlx += w * x * l;
suml2 += w * l * l;
}
let mut scale = sumlx / suml2;
let mut best = scale * sumlx;
for _itry in 0..3 {
let iscale = 1.0 / scale;
let mut slx = 0f32;
let mut sl2 = 0f32;
let mut changed = false;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
if l + nmax != *ls.add(i) as i32 {
changed = true;
}
let w = if weight_type == 1 { x * x } else { 1f32 };
let l = l as f32;
slx += w * x * l;
sl2 += w * l * l;
}
if !changed || sl2 == 0.0 || slx * slx <= best * sl2 {
break;
}
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
sumlx = slx;
suml2 = sl2;
scale = sumlx / suml2;
best = scale * sumlx;
}
for _itry in 0..5 {
let mut n_changed = 0;
for i in 0..n {
let x = *x.add(i);
let w = if weight_type == 1 { x * x } else { 1. };
let l = *ls.add(i) as i32 - nmax;
let mut slx = sumlx - w * x * l as f32;
if slx > 0. {
let mut sl2 = suml2 - w * l as f32 * l as f32;
let new_l = nearest_int(x * sl2 / slx);
let new_l = new_l.clamp(-nmax, nmax - 1);
if new_l != l {
slx += w * x * new_l as f32;
sl2 += w * new_l as f32 * new_l as f32;
if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 {
*ls.add(i) = (nmax + new_l) as i8;
sumlx = slx;
suml2 = sl2;
scale = sumlx / suml2;
best = scale * sumlx;
n_changed += 1;
}
}
}
}
if n_changed == 0 {
break;
}
}
if rmse_type < 3 {
return scale;
}
for is in -4..4 {
if is == 0 {
continue;
}
iscale = -(nmax as f32 + 0.1f32 * is as f32) / max;
let mut sumlx = 0.;
let mut suml2 = 0.;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = i32::max(-nmax, i32::min(nmax - 1, l));
let w = if weight_type == 1 { x * x } else { 1. };
let l = l as f32;
sumlx += w * x * l;
suml2 += w * l * l;
}
if suml2 > 0. && sumlx * sumlx > best * suml2 {
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
scale = sumlx / suml2;
best = scale * sumlx;
}
}
scale
}
impl GgmlType for BlockQ6K {
const DTYPE: GgmlDType = GgmlDType::Q6K;
const BLCK_SIZE: usize = QK_K;
@ -524,8 +661,76 @@ impl GgmlType for BlockQ6K {
Ok(sums.iter().sum())
}
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
todo!()
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
if xs.len() != ys.len() * Self::BLCK_SIZE {
crate::bail!(
"quantize_row_q6k: size mismatch {} {} {}",
xs.len(),
ys.len(),
Self::BLCK_SIZE
)
}
let mut l = [0i8; QK_K];
let mut scales = [0f32; QK_K / 16];
let mut x = xs.as_ptr();
let l = l.as_mut_ptr();
unsafe {
for y in ys.iter_mut() {
let mut max_scale = 0f32;
let mut max_abs_scale = 0f32;
for (ib, scale_) in scales.iter_mut().enumerate() {
let scale = make_qx_quants(16, 32, x.add(16 * ib), l.add(16 * ib), 1);
*scale_ = scale;
let abs_scale = scale.abs();
if abs_scale > max_abs_scale {
max_abs_scale = abs_scale;
max_scale = scale
}
}
let iscale = -128f32 / max_scale;
y.d = f16::from_f32(1.0 / iscale);
for (y_scale, scale) in y.scales.iter_mut().zip(scales.iter()) {
*y_scale = nearest_int(iscale * scale).min(127) as i8
}
for (j, &y_scale) in y.scales.iter().enumerate() {
let d = y.d.to_f32() * y_scale as f32;
if d == 0. {
continue;
}
for ii in 0..16 {
let ll = nearest_int(*x.add(16 * j + ii) / d).clamp(-32, 31);
*l.add(16 * j + ii) = (ll + 32) as i8
}
}
let mut ql = y.ql.as_mut_ptr();
let mut qh = y.qh.as_mut_ptr();
for j in (0..QK_K).step_by(128) {
for l_idx in 0..32 {
let q1 = *l.add(j + l_idx) & 0xF;
let q2 = *l.add(j + l_idx + 32) & 0xF;
let q3 = *l.add(j + l_idx + 64) & 0xF;
let q4 = *l.add(j + l_idx + 96) & 0xF;
*ql.add(l_idx) = (q1 | (q3 << 4)) as u8;
*ql.add(l_idx + 32) = (q2 | (q4 << 4)) as u8;
*qh.add(l_idx) = ((*l.add(j + l_idx) >> 4)
| ((*l.add(j + l_idx + 32) >> 4) << 2)
| ((*l.add(j + l_idx + 64) >> 4) << 4)
| ((*l.add(j + l_idx + 96) >> 4) << 6))
as u8;
}
ql = ql.add(64);
qh = qh.add(32);
}
x = x.add(QK_K)
}
}
Ok(())
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067