diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 366eca1e..28ac896e 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -510,16 +510,59 @@ impl GgmlType for BlockQ8K { type VecDotType = BlockQ8K; fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result { - todo!() + unreachable!() } - fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> { - todo!() + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + let k = xs.len(); + if k % QK_K != 0 { + crate::bail!("quantize_row_q8k: {k} is not divisible by {QK_K}") + } + for (i, y) in ys.iter_mut().enumerate() { + let mut max = 0f32; + let mut amax = 0f32; + let xs = &xs[i * QK_K..(i + 1) * QK_K]; + for &x in xs.iter() { + if amax < x.abs() { + amax = x.abs(); + max = x; + } + } + if amax == 0f32 { + y.d = 0f32; + y.qs.fill(0) + } else { + let iscale = -128f32 / max; + for (j, q) in y.qs.iter_mut().enumerate() { + // ggml uses nearest_int with bit magic here, maybe we want the same + // but we would have to test and benchmark it. + let v = (iscale * xs[j]).round(); + *q = v.min(127.) as i8 + } + for j in 0..QK_K / 16 { + let mut sum = 0i32; + for ii in 0..16 { + sum += y.qs[j * 16 + ii] as i32 + } + y.bsums[j] = sum as i16 + } + y.d = 1.0 / iscale + } + } + Ok(()) } - // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533 - fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> { - todo!() + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize_row_q8k: {k} is not divisible by {QK_K}") + } + for (i, x) in xs.iter().enumerate() { + for (j, &q) in x.qs.iter().enumerate() { + ys[i * QK_K + j] = x.d * q as f32 + } + } + Ok(()) } } @@ -601,14 +644,14 @@ impl GgmlType for BlockQ4_0 { // Generic implementation. let mut sumf = 0f32; - for i in 0..nb { + for (xs, ys) in xs.iter().zip(ys.iter()) { let mut sum_i = 0; for j in 0..qk / 2 { - let v0 = (xs[i].qs[j] & 0x0F) as i32 - 8; - let v1 = (xs[i].qs[j] >> 4) as i32 - 8; - sum_i += v0 * ys[i].qs[j] as i32 + v1 * ys[i].qs[j + qk / 2] as i32 + let v0 = (xs.qs[j] & 0x0F) as i32 - 8; + let v1 = (xs.qs[j] >> 4) as i32 - 8; + sum_i += v0 * ys.qs[j] as i32 + v1 * ys.qs[j + qk / 2] as i32 } - sumf += sum_i as f32 * f16::to_f32(xs[i].d) * f16::to_f32(ys[i].d) + sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) } Ok(sumf) }