mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Q6K quantization (#495)
* Print the detected arch options. * Add the q6k quantization. * Add a currently broken test. * Bugfix. * Bugfix. * Another bugfix. * Another bugfix + get the test to work.
This commit is contained in:
@ -457,6 +457,143 @@ impl GgmlType for BlockQ5K {
|
||||
}
|
||||
}
|
||||
|
||||
/// Rounds `v` to the closest integer and converts the result to `i32`.
///
/// Halfway cases are rounded away from zero (the behavior of `f32::round`). The `as` cast
/// saturates at the `i32` bounds and maps NaN to 0.
fn nearest_int(v: f32) -> i32 {
    let rounded = v.round();
    rounded as i32
}
|
||||
|
||||
unsafe fn make_qx_quants(n: usize, nmax: i32, x: *const f32, ls: *mut i8, rmse_type: i32) -> f32 {
|
||||
let mut max = 0f32;
|
||||
let mut amax = 0f32;
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let ax = x.abs();
|
||||
if ax > amax {
|
||||
amax = ax;
|
||||
max = x;
|
||||
}
|
||||
}
|
||||
if amax == 0. {
|
||||
// all zero
|
||||
for i in 0..n {
|
||||
*ls.add(i) = 0;
|
||||
}
|
||||
return 0.;
|
||||
}
|
||||
let mut iscale = -(nmax as f32) / max;
|
||||
if rmse_type == 0 {
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let l = nearest_int(iscale * x);
|
||||
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
|
||||
}
|
||||
return 1.0 / iscale;
|
||||
}
|
||||
let weight_type = rmse_type % 2;
|
||||
let mut sumlx = 0f32;
|
||||
let mut suml2 = 0f32;
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let l = nearest_int(iscale * x);
|
||||
let l = l.clamp(-nmax, nmax - 1);
|
||||
*ls.add(i) = (l + nmax) as i8;
|
||||
let w = if weight_type == 1 { x * x } else { 1.0 };
|
||||
let l = l as f32;
|
||||
sumlx += w * x * l;
|
||||
suml2 += w * l * l;
|
||||
}
|
||||
let mut scale = sumlx / suml2;
|
||||
let mut best = scale * sumlx;
|
||||
for _itry in 0..3 {
|
||||
let iscale = 1.0 / scale;
|
||||
let mut slx = 0f32;
|
||||
let mut sl2 = 0f32;
|
||||
let mut changed = false;
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let l = nearest_int(iscale * x);
|
||||
let l = l.clamp(-nmax, nmax - 1);
|
||||
if l + nmax != *ls.add(i) as i32 {
|
||||
changed = true;
|
||||
}
|
||||
let w = if weight_type == 1 { x * x } else { 1f32 };
|
||||
let l = l as f32;
|
||||
slx += w * x * l;
|
||||
sl2 += w * l * l;
|
||||
}
|
||||
if !changed || sl2 == 0.0 || slx * slx <= best * sl2 {
|
||||
break;
|
||||
}
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let l = nearest_int(iscale * x);
|
||||
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
|
||||
}
|
||||
sumlx = slx;
|
||||
suml2 = sl2;
|
||||
scale = sumlx / suml2;
|
||||
best = scale * sumlx;
|
||||
}
|
||||
for _itry in 0..5 {
|
||||
let mut n_changed = 0;
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let w = if weight_type == 1 { x * x } else { 1. };
|
||||
let l = *ls.add(i) as i32 - nmax;
|
||||
let mut slx = sumlx - w * x * l as f32;
|
||||
if slx > 0. {
|
||||
let mut sl2 = suml2 - w * l as f32 * l as f32;
|
||||
let new_l = nearest_int(x * sl2 / slx);
|
||||
let new_l = new_l.clamp(-nmax, nmax - 1);
|
||||
if new_l != l {
|
||||
slx += w * x * new_l as f32;
|
||||
sl2 += w * new_l as f32 * new_l as f32;
|
||||
if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 {
|
||||
*ls.add(i) = (nmax + new_l) as i8;
|
||||
sumlx = slx;
|
||||
suml2 = sl2;
|
||||
scale = sumlx / suml2;
|
||||
best = scale * sumlx;
|
||||
n_changed += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if n_changed == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if rmse_type < 3 {
|
||||
return scale;
|
||||
}
|
||||
for is in -4..4 {
|
||||
if is == 0 {
|
||||
continue;
|
||||
}
|
||||
iscale = -(nmax as f32 + 0.1f32 * is as f32) / max;
|
||||
let mut sumlx = 0.;
|
||||
let mut suml2 = 0.;
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let l = nearest_int(iscale * x);
|
||||
let l = i32::max(-nmax, i32::min(nmax - 1, l));
|
||||
let w = if weight_type == 1 { x * x } else { 1. };
|
||||
let l = l as f32;
|
||||
sumlx += w * x * l;
|
||||
suml2 += w * l * l;
|
||||
}
|
||||
if suml2 > 0. && sumlx * sumlx > best * suml2 {
|
||||
for i in 0..n {
|
||||
let x = *x.add(i);
|
||||
let l = nearest_int(iscale * x);
|
||||
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
|
||||
}
|
||||
scale = sumlx / suml2;
|
||||
best = scale * sumlx;
|
||||
}
|
||||
}
|
||||
scale
|
||||
}
|
||||
|
||||
impl GgmlType for BlockQ6K {
|
||||
const DTYPE: GgmlDType = GgmlDType::Q6K;
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
@ -524,8 +661,76 @@ impl GgmlType for BlockQ6K {
|
||||
Ok(sums.iter().sum())
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
todo!()
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
if xs.len() != ys.len() * Self::BLCK_SIZE {
|
||||
crate::bail!(
|
||||
"quantize_row_q6k: size mismatch {} {} {}",
|
||||
xs.len(),
|
||||
ys.len(),
|
||||
Self::BLCK_SIZE
|
||||
)
|
||||
}
|
||||
let mut l = [0i8; QK_K];
|
||||
let mut scales = [0f32; QK_K / 16];
|
||||
let mut x = xs.as_ptr();
|
||||
let l = l.as_mut_ptr();
|
||||
unsafe {
|
||||
for y in ys.iter_mut() {
|
||||
let mut max_scale = 0f32;
|
||||
let mut max_abs_scale = 0f32;
|
||||
for (ib, scale_) in scales.iter_mut().enumerate() {
|
||||
let scale = make_qx_quants(16, 32, x.add(16 * ib), l.add(16 * ib), 1);
|
||||
*scale_ = scale;
|
||||
let abs_scale = scale.abs();
|
||||
if abs_scale > max_abs_scale {
|
||||
max_abs_scale = abs_scale;
|
||||
max_scale = scale
|
||||
}
|
||||
}
|
||||
|
||||
let iscale = -128f32 / max_scale;
|
||||
y.d = f16::from_f32(1.0 / iscale);
|
||||
|
||||
for (y_scale, scale) in y.scales.iter_mut().zip(scales.iter()) {
|
||||
*y_scale = nearest_int(iscale * scale).min(127) as i8
|
||||
}
|
||||
|
||||
for (j, &y_scale) in y.scales.iter().enumerate() {
|
||||
let d = y.d.to_f32() * y_scale as f32;
|
||||
if d == 0. {
|
||||
continue;
|
||||
}
|
||||
for ii in 0..16 {
|
||||
let ll = nearest_int(*x.add(16 * j + ii) / d).clamp(-32, 31);
|
||||
*l.add(16 * j + ii) = (ll + 32) as i8
|
||||
}
|
||||
}
|
||||
|
||||
let mut ql = y.ql.as_mut_ptr();
|
||||
let mut qh = y.qh.as_mut_ptr();
|
||||
|
||||
for j in (0..QK_K).step_by(128) {
|
||||
for l_idx in 0..32 {
|
||||
let q1 = *l.add(j + l_idx) & 0xF;
|
||||
let q2 = *l.add(j + l_idx + 32) & 0xF;
|
||||
let q3 = *l.add(j + l_idx + 64) & 0xF;
|
||||
let q4 = *l.add(j + l_idx + 96) & 0xF;
|
||||
*ql.add(l_idx) = (q1 | (q3 << 4)) as u8;
|
||||
*ql.add(l_idx + 32) = (q2 | (q4 << 4)) as u8;
|
||||
*qh.add(l_idx) = ((*l.add(j + l_idx) >> 4)
|
||||
| ((*l.add(j + l_idx + 32) >> 4) << 2)
|
||||
| ((*l.add(j + l_idx + 64) >> 4) << 4)
|
||||
| ((*l.add(j + l_idx + 96) >> 4) << 6))
|
||||
as u8;
|
||||
}
|
||||
ql = ql.add(64);
|
||||
qh = qh.add(32);
|
||||
}
|
||||
|
||||
x = x.add(QK_K)
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
|
||||
|
@ -22,3 +22,19 @@ pub fn has_mkl() -> bool {
|
||||
/// Reports whether this crate was compiled with the `cuda` cargo feature enabled.
///
/// The answer is fixed at compile time via `cfg!`; nothing is probed at runtime.
pub fn cuda_is_available() -> bool {
    const BUILT_WITH_CUDA: bool = cfg!(feature = "cuda");
    BUILT_WITH_CUDA
}
|
||||
|
||||
/// Reports whether this build was compiled with the `avx` target feature enabled.
///
/// The answer is fixed at compile time via `cfg!`; no runtime CPU detection is performed.
pub fn with_avx() -> bool {
    const COMPILED_WITH_AVX: bool = cfg!(target_feature = "avx");
    COMPILED_WITH_AVX
}
|
||||
|
||||
/// Reports whether this build was compiled with the `neon` target feature enabled.
///
/// The answer is fixed at compile time via `cfg!`; no runtime CPU detection is performed.
pub fn with_neon() -> bool {
    const COMPILED_WITH_NEON: bool = cfg!(target_feature = "neon");
    COMPILED_WITH_NEON
}
|
||||
|
||||
/// Reports whether this build was compiled with the `simd128` target feature enabled.
///
/// The answer is fixed at compile time via `cfg!`; no runtime detection is performed.
pub fn with_simd128() -> bool {
    const COMPILED_WITH_SIMD128: bool = cfg!(target_feature = "simd128");
    COMPILED_WITH_SIMD128
}
|
||||
|
||||
/// Reports whether this build was compiled with the `f16c` target feature enabled.
///
/// The answer is fixed at compile time via `cfg!`; no runtime CPU detection is performed.
pub fn with_f16c() -> bool {
    const COMPILED_WITH_F16C: bool = cfg!(target_feature = "f16c");
    COMPILED_WITH_F16C
}
|
||||
|
@ -145,3 +145,29 @@ fn quantize_q8k() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Round-trips a linear ramp over [-0.5, 0.5) through Q6K quantization and spot-checks a
/// handful of positions, both in the input and in the dequantized output.
#[test]
fn quantize_q6k() -> Result<()> {
    use k_quants::BlockQ6K;

    const BLOCKS: usize = 4;
    const LEN: usize = 256 * BLOCKS;
    // Positions compared below, spread over all four blocks.
    const PROBES: [usize; 6] = [0, 128, 256, 512, 800, 1023];

    let src: Vec<f32> = (0..LEN).map(|v| (v as f32 - 512.) / 1024.).collect();

    let mut quant = vec![BlockQ6K::zeros(); BLOCKS];
    BlockQ6K::from_float(&src, &mut quant)?;
    let mut dst = vec![0f32; LEN];
    BlockQ6K::to_float(&quant, dst.as_mut_slice())?;

    // Sanity check on the generated input itself.
    let src_samples = PROBES.map(|i| src[i]);
    assert_eq!(
        src_samples,
        [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344]
    );

    // Compare the round-tripped values after rounding to three decimals.
    let rounded: Vec<f32> = dst.iter().map(|x| (1000. * x).round() / 1000.).collect();
    let dst_samples = PROBES.map(|i| rounded[i]);
    assert_eq!(dst_samples, [-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]);
    Ok(())
}
|
||||
|
@ -348,6 +348,14 @@ fn main() -> anyhow::Result<()> {
|
||||
None
|
||||
};
|
||||
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
|
||||
let mut file = std::fs::File::open(&args.model()?)?;
|
||||
let start = std::time::Instant::now();
|
||||
let model = Content::read(&mut file)?;
|
||||
|
Reference in New Issue
Block a user