mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
More quantization support (#455)
* Properly initialize wdata. * Simplify the matmul bits. * Add from_float for q4_0. * Fix a couple bugs. * Get the test to work. * Get clippy to be happy.
This commit is contained in:
@ -15,18 +15,24 @@ pub const QK5_1: usize = 32;
|
|||||||
pub const QK8_0: usize = 32;
|
pub const QK8_0: usize = 32;
|
||||||
pub const QK8_1: usize = 32;
|
pub const QK8_1: usize = 32;
|
||||||
|
|
||||||
pub trait GgmlType: Sized {
|
pub trait GgmlType: Sized + Clone {
|
||||||
const DTYPE: GgmlDType;
|
const DTYPE: GgmlDType;
|
||||||
const BLCK_SIZE: usize;
|
const BLCK_SIZE: usize;
|
||||||
|
type VecDotType: GgmlType;
|
||||||
|
|
||||||
|
// This is only safe for types that include immediate values such as float/int/...
|
||||||
|
fn zeros() -> Self {
|
||||||
|
unsafe { std::mem::MaybeUninit::zeroed().assume_init() }
|
||||||
|
}
|
||||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>;
|
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>;
|
||||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>;
|
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>;
|
||||||
|
|
||||||
type VecDotType: GgmlType;
|
/// Dot product used as a building block for quantized mat-mul.
|
||||||
// Dot product used as a building block for quantized mat-mul.
|
/// n is the number of elements to be considered.
|
||||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
|
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct BlockQ4_0 {
|
pub struct BlockQ4_0 {
|
||||||
d: f16,
|
d: f16,
|
||||||
@ -34,6 +40,7 @@ pub struct BlockQ4_0 {
|
|||||||
}
|
}
|
||||||
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
|
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct BlockQ4_1 {
|
pub struct BlockQ4_1 {
|
||||||
d: f16,
|
d: f16,
|
||||||
@ -42,6 +49,7 @@ pub struct BlockQ4_1 {
|
|||||||
}
|
}
|
||||||
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
|
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct BlockQ5_0 {
|
pub struct BlockQ5_0 {
|
||||||
d: f16,
|
d: f16,
|
||||||
@ -50,6 +58,7 @@ pub struct BlockQ5_0 {
|
|||||||
}
|
}
|
||||||
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
|
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct BlockQ5_1 {
|
pub struct BlockQ5_1 {
|
||||||
d: f16,
|
d: f16,
|
||||||
@ -59,6 +68,7 @@ pub struct BlockQ5_1 {
|
|||||||
}
|
}
|
||||||
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
|
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct BlockQ8_0 {
|
pub struct BlockQ8_0 {
|
||||||
d: f16,
|
d: f16,
|
||||||
@ -66,6 +76,7 @@ pub struct BlockQ8_0 {
|
|||||||
}
|
}
|
||||||
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
|
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct BlockQ8_1 {
|
pub struct BlockQ8_1 {
|
||||||
d: f16,
|
d: f16,
|
||||||
@ -74,6 +85,7 @@ pub struct BlockQ8_1 {
|
|||||||
}
|
}
|
||||||
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
|
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct BlockQ2K {
|
pub struct BlockQ2K {
|
||||||
scales: [u8; QK_K / 16],
|
scales: [u8; QK_K / 16],
|
||||||
@ -83,6 +95,7 @@ pub struct BlockQ2K {
|
|||||||
}
|
}
|
||||||
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
|
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct BlockQ3K {
|
pub struct BlockQ3K {
|
||||||
hmask: [u8; QK_K / 8],
|
hmask: [u8; QK_K / 8],
|
||||||
@ -92,6 +105,7 @@ pub struct BlockQ3K {
|
|||||||
}
|
}
|
||||||
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
|
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
|
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct BlockQ4K {
|
pub struct BlockQ4K {
|
||||||
@ -102,6 +116,7 @@ pub struct BlockQ4K {
|
|||||||
}
|
}
|
||||||
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
|
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct BlockQ5K {
|
pub struct BlockQ5K {
|
||||||
d: f16,
|
d: f16,
|
||||||
@ -113,6 +128,7 @@ pub struct BlockQ5K {
|
|||||||
const _: () =
|
const _: () =
|
||||||
assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
|
assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct BlockQ6K {
|
pub struct BlockQ6K {
|
||||||
ql: [u8; QK_K / 2],
|
ql: [u8; QK_K / 2],
|
||||||
@ -122,6 +138,7 @@ pub struct BlockQ6K {
|
|||||||
}
|
}
|
||||||
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
|
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct BlockQ8K {
|
pub struct BlockQ8K {
|
||||||
d: f32,
|
d: f32,
|
||||||
@ -535,8 +552,41 @@ impl GgmlType for BlockQ4_0 {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_float(_: &[f32], _: &mut [Self]) -> Result<()> {
|
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||||
todo!()
|
// quantize_row_q4_0
|
||||||
|
let qk = Self::BLCK_SIZE;
|
||||||
|
let k = xs.len();
|
||||||
|
if k % qk != 0 {
|
||||||
|
crate::bail!("{k} is not divisible by {}", qk);
|
||||||
|
};
|
||||||
|
let nb = k / qk;
|
||||||
|
if ys.len() != nb {
|
||||||
|
crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,)
|
||||||
|
}
|
||||||
|
for (i, ys) in ys.iter_mut().enumerate() {
|
||||||
|
let mut amax = 0f32;
|
||||||
|
let mut max = 0f32;
|
||||||
|
|
||||||
|
let xs = &xs[i * qk..(i + 1) * qk];
|
||||||
|
for &x in xs.iter() {
|
||||||
|
if amax < x.abs() {
|
||||||
|
amax = x.abs();
|
||||||
|
max = x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let d = max / -8.0;
|
||||||
|
let id = if d != 0f32 { 1. / d } else { 0. };
|
||||||
|
ys.d = f16::from_f32(d);
|
||||||
|
|
||||||
|
for (j, q) in ys.qs.iter_mut().enumerate() {
|
||||||
|
let x0 = xs[j] * id;
|
||||||
|
let x1 = xs[qk / 2 + j] * id;
|
||||||
|
let xi0 = u8::min(15, (x0 + 8.5) as u8);
|
||||||
|
let xi1 = u8::min(15, (x1 + 8.5) as u8);
|
||||||
|
*q = xi0 | (xi1 << 4)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122
|
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122
|
||||||
@ -555,9 +605,9 @@ impl GgmlType for BlockQ4_0 {
|
|||||||
for i in 0..nb {
|
for i in 0..nb {
|
||||||
let mut sum_i = 0;
|
let mut sum_i = 0;
|
||||||
for j in 0..qk / 2 {
|
for j in 0..qk / 2 {
|
||||||
let v0 = (xs[i].qs[j] & 0x0F) - 8;
|
let v0 = (xs[i].qs[j] & 0x0F) as i32 - 8;
|
||||||
let v1 = (xs[i].qs[j] >> 4) - 8;
|
let v1 = (xs[i].qs[j] >> 4) as i32 - 8;
|
||||||
sum_i += v0 * ys[i].qs[j] + v1 * ys[i].qs[j + qk / 2]
|
sum_i += v0 * ys[i].qs[j] as i32 + v1 * ys[i].qs[j + qk / 2] as i32
|
||||||
}
|
}
|
||||||
sumf += sum_i as f32 * f16::to_f32(xs[i].d) * f16::to_f32(ys[i].d)
|
sumf += sum_i as f32 * f16::to_f32(xs[i].d) * f16::to_f32(ys[i].d)
|
||||||
}
|
}
|
||||||
@ -591,7 +641,7 @@ impl GgmlType for BlockQ8_0 {
|
|||||||
|
|
||||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||||
// quantize_row_q8_0
|
// quantize_row_q8_0
|
||||||
let k = ys.len();
|
let k = xs.len();
|
||||||
if k % Self::BLCK_SIZE != 0 {
|
if k % Self::BLCK_SIZE != 0 {
|
||||||
crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE);
|
crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE);
|
||||||
};
|
};
|
||||||
@ -608,7 +658,7 @@ impl GgmlType for BlockQ8_0 {
|
|||||||
let mut amax = 0f32;
|
let mut amax = 0f32;
|
||||||
let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
|
let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
|
||||||
for &x in xs.iter() {
|
for &x in xs.iter() {
|
||||||
amax = amax.max(x)
|
amax = amax.max(x.abs())
|
||||||
}
|
}
|
||||||
let d = amax / ((1 << 7) - 1) as f32;
|
let d = amax / ((1 << 7) - 1) as f32;
|
||||||
let id = if d != 0f32 { 1. / d } else { 0. };
|
let id = if d != 0f32 { 1. / d } else { 0. };
|
||||||
@ -644,73 +694,36 @@ impl GgmlType for BlockQ8_1 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const BLCK0: usize = 16;
|
|
||||||
const BLCK1: usize = 16;
|
|
||||||
|
|
||||||
// This implementation is in-line with the ggml one and keeps the same variable names.
|
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605
|
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605
|
||||||
pub fn forward_mul_mat<T: GgmlType>(src0: &[T], src1: &[f32], dst: &mut [f32]) -> Result<()> {
|
pub fn matmul<T: GgmlType>(
|
||||||
// TODO: Use the proper sizes here.
|
mkn: (usize, usize, usize),
|
||||||
let (ne00, ne01, ne02, ne03) = (1, 1, 1, 1);
|
lhs: &[f32],
|
||||||
let (ne10, ne11, ne12, ne13) = (1, 1, 1, 1);
|
rhs_t: &[T],
|
||||||
// The strides are in bytes in ggml, however we use the number of elements in candle.
|
dst: &mut [f32],
|
||||||
let (_, nb1, nb2, nb3) = (1, 1, 1, 1);
|
) -> Result<()> {
|
||||||
let (_, nb01, nb02, nb03) = (1, 1, 1, 1);
|
let (m, k, n) = mkn;
|
||||||
let (_, nb11, nb12, nb13) = (1, 1, 1, 1);
|
if m * k != lhs.len() {
|
||||||
|
crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
|
||||||
let nr0 = ne01; // src0 rows
|
}
|
||||||
let nr1 = ne11 * ne12 * ne13;
|
|
||||||
|
|
||||||
// TODO: Either add multi-threading or remove these bits.
|
|
||||||
let ir010 = 0;
|
|
||||||
let ir011 = nr0;
|
|
||||||
let ir110 = 0;
|
|
||||||
let ir111 = nr1;
|
|
||||||
let r2 = ne12 / ne02;
|
|
||||||
let r3 = ne13 / ne03;
|
|
||||||
|
|
||||||
|
let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE;
|
||||||
|
let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE;
|
||||||
|
// TODO: Do not make this copy if the DotType is f32.
|
||||||
// TODO: Pre-allocate this.
|
// TODO: Pre-allocate this.
|
||||||
let wdata = &mut [];
|
let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];
|
||||||
if ne10 % T::BLCK_SIZE != 0 {
|
for row_idx in 0..m {
|
||||||
crate::bail!(
|
let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
|
||||||
"forward_mul_mat: ne10 {ne10} is not divisible by block size {}",
|
let lhs = &lhs[row_idx * k..(row_idx + 1) * k];
|
||||||
T::BLCK_SIZE
|
T::VecDotType::from_float(lhs, lhs_b)?
|
||||||
)
|
|
||||||
}
|
}
|
||||||
let row_size = ne10 / T::BLCK_SIZE;
|
let lhs_b = lhs_b.as_slice();
|
||||||
for i13 in 0..ne13 {
|
|
||||||
for i12 in 0..ne12 {
|
|
||||||
for i11 in 0..ne11 {
|
|
||||||
let wdata_idx = i11 + i12 * ne11 + i13 * ne11 * ne12;
|
|
||||||
let wdata = &mut wdata[wdata_idx..wdata_idx + row_size];
|
|
||||||
let src1 = &src1[i13 * nb13 + i12 * nb12 + i11 * nb11..];
|
|
||||||
T::VecDotType::from_float(src1, wdata)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for iir1 in (ir110..ir111).step_by(BLCK1) {
|
|
||||||
for iir0 in (ir010..ir011).step_by(BLCK0) {
|
|
||||||
for ir1 in iir1..usize::min(iir1 + BLCK1, ir111) {
|
|
||||||
let i13 = ir1 / (ne12 * ne11);
|
|
||||||
let i12 = (ir1 - i13 * ne12 * ne11) / ne11;
|
|
||||||
let i11 = ir1 - i13 * ne12 * ne11 - i12 * ne11;
|
|
||||||
|
|
||||||
let i03 = i13 / r3;
|
for row_idx in 0..m {
|
||||||
let i02 = i12 / r2;
|
let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
|
||||||
|
let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
|
||||||
let i1 = i11;
|
for (col_idx, dst) in dst_row.iter_mut().enumerate() {
|
||||||
let i2 = i12;
|
let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
|
||||||
let i3 = i13;
|
*dst = T::vec_dot(k, rhs_col, lhs_row)?;
|
||||||
|
|
||||||
let src0_row = &src0[i02 * nb02 + i03 * nb03..];
|
|
||||||
let src1_col = &wdata[(i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size..];
|
|
||||||
let dst_col = &mut dst[i1 * nb1 + i2 * nb2 + i3 * nb3..];
|
|
||||||
for ir0 in iir0..usize::min(iir0 + BLCK0, ir011) {
|
|
||||||
let src0_row = &src0_row[ir0 * nb01..];
|
|
||||||
let v = T::vec_dot(ne00, src0_row, src1_col)?;
|
|
||||||
dst_col[ir0 - iir0] += v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
33
candle-core/tests/ggml_tests.rs
Normal file
33
candle-core/tests/ggml_tests.rs
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
use candle_core::{ggml, Device, Result, Tensor};
|
||||||
|
use ggml::GgmlType;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ggml_matmul() -> Result<()> {
|
||||||
|
let cpu = &Device::Cpu;
|
||||||
|
let (m, k, n) = (3, 64, 4);
|
||||||
|
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
|
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
|
||||||
|
let mut dst = vec![42.; 3 * 4];
|
||||||
|
let mut rhs_t = vec![ggml::BlockQ4_0::zeros(); 8];
|
||||||
|
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
|
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
|
||||||
|
ggml::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||||
|
ggml::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||||
|
assert_eq!(
|
||||||
|
dst,
|
||||||
|
&[
|
||||||
|
85120.43, 214561.61, 345454.9, 474748.1, 213474.94, 604465.25, 1000686.4, 1388317.3,
|
||||||
|
341875.88, 994283.0, 1655708.8, 2301518.3
|
||||||
|
]
|
||||||
|
);
|
||||||
|
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
||||||
|
assert_eq!(
|
||||||
|
mm.to_vec2::<f32>()?,
|
||||||
|
&[
|
||||||
|
[85344.0, 214368.0, 343392.0, 472416.0],
|
||||||
|
[214368.0, 605536.0, 996704.0, 1387872.0],
|
||||||
|
[343392.0, 996704.0, 1650016.0, 2303328.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
Reference in New Issue
Block a user