From 08effe376224d2071a7371a7cf85f8899f3a69be Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 15 Aug 2023 18:58:04 +0100 Subject: [PATCH] More quantization support (#455) * Properly initialize wdata. * Simplify the matmul bits. * Add from_float for q4_0. * Fix a couple bugs. * Get the test to work. * Get clippy to be happy. --- candle-core/src/ggml.rs | 157 +++++++++++++++++--------------- candle-core/tests/ggml_tests.rs | 33 +++++++ 2 files changed, 118 insertions(+), 72 deletions(-) create mode 100644 candle-core/tests/ggml_tests.rs diff --git a/candle-core/src/ggml.rs b/candle-core/src/ggml.rs index 0b3dee04..3a41eeec 100644 --- a/candle-core/src/ggml.rs +++ b/candle-core/src/ggml.rs @@ -15,18 +15,24 @@ pub const QK5_1: usize = 32; pub const QK8_0: usize = 32; pub const QK8_1: usize = 32; -pub trait GgmlType: Sized { +pub trait GgmlType: Sized + Clone { const DTYPE: GgmlDType; const BLCK_SIZE: usize; + type VecDotType: GgmlType; + // This is only safe for types that include immediate values such as float/int/... + fn zeros() -> Self { + unsafe { std::mem::MaybeUninit::zeroed().assume_init() } + } fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>; fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>; - type VecDotType: GgmlType; - // Dot product used as a building block for quantized mat-mul. + /// Dot product used as a building block for quantized mat-mul. + /// n is the number of elements to be considered. fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; } +#[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ4_0 { d: f16, @@ -34,6 +40,7 @@ pub struct BlockQ4_0 { } const _: () = assert!(std::mem::size_of::() == 18); +#[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ4_1 { d: f16, @@ -42,6 +49,7 @@ pub struct BlockQ4_1 { } const _: () = assert!(std::mem::size_of::() == 20); +#[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ5_0 { d: f16, @@ -50,6 +58,7 @@ pub struct BlockQ5_0 { } const _: () = assert!(std::mem::size_of::() == 22); +#[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ5_1 { d: f16, @@ -59,6 +68,7 @@ pub struct BlockQ5_1 { } const _: () = assert!(std::mem::size_of::() == 24); +#[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ8_0 { d: f16, @@ -66,6 +76,7 @@ pub struct BlockQ8_0 { } const _: () = assert!(std::mem::size_of::() == 34); +#[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ8_1 { d: f16, @@ -74,6 +85,7 @@ pub struct BlockQ8_1 { } const _: () = assert!(std::mem::size_of::() == 36); +#[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ2K { scales: [u8; QK_K / 16], @@ -83,6 +95,7 @@ pub struct BlockQ2K { } const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::()); +#[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ3K { hmask: [u8; QK_K / 8], @@ -92,6 +105,7 @@ pub struct BlockQ3K { } const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::()); +#[derive(Debug, Clone, PartialEq)] // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82 #[repr(C)] pub struct BlockQ4K { @@ -102,6 +116,7 @@ pub struct BlockQ4K { } const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::()); +#[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ5K { d: f16, @@ -113,6 +128,7 @@ pub struct BlockQ5K { const _: () = assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::()); +#[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ6K { ql: [u8; QK_K / 2], @@ -122,6 +138,7 @@ pub struct BlockQ6K { } const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::()); +#[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ8K { d: f32, @@ -535,8 +552,41 @@ impl GgmlType for BlockQ4_0 { Ok(()) } - fn from_float(_: &[f32], _: &mut [Self]) -> Result<()> { - todo!() + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + // quantize_row_q4_0 + let qk = Self::BLCK_SIZE; + let k = xs.len(); + if k % qk != 0 { + crate::bail!("{k} is not divisible by {}", qk); + }; + let nb = k / qk; + if ys.len() != nb { + crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,) + } + for (i, ys) in ys.iter_mut().enumerate() { + let mut amax = 0f32; + let mut max = 0f32; + + let xs = &xs[i * qk..(i + 1) * qk]; + for &x in xs.iter() { + if amax < x.abs() { + amax = x.abs(); + max = x; + } + } + let d = max / -8.0; + let id = if d != 0f32 { 1. / d } else { 0. }; + ys.d = f16::from_f32(d); + + for (j, q) in ys.qs.iter_mut().enumerate() { + let x0 = xs[j] * id; + let x1 = xs[qk / 2 + j] * id; + let xi0 = u8::min(15, (x0 + 8.5) as u8); + let xi1 = u8::min(15, (x1 + 8.5) as u8); + *q = xi0 | (xi1 << 4) + } + } + Ok(()) } // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122 @@ -555,9 +605,9 @@ impl GgmlType for BlockQ4_0 { for i in 0..nb { let mut sum_i = 0; for j in 0..qk / 2 { - let v0 = (xs[i].qs[j] & 0x0F) - 8; - let v1 = (xs[i].qs[j] >> 4) - 8; - sum_i += v0 * ys[i].qs[j] + v1 * ys[i].qs[j + qk / 2] + let v0 = (xs[i].qs[j] & 0x0F) as i32 - 8; + let v1 = (xs[i].qs[j] >> 4) as i32 - 8; + sum_i += v0 * ys[i].qs[j] as i32 + v1 * ys[i].qs[j + qk / 2] as i32 } sumf += sum_i as f32 * f16::to_f32(xs[i].d) * f16::to_f32(ys[i].d) } @@ -591,7 +641,7 @@ impl GgmlType for BlockQ8_0 { fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { // quantize_row_q8_0 - let k = ys.len(); + let k = xs.len(); if k % Self::BLCK_SIZE != 0 { crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE); }; @@ -608,7 +658,7 @@ impl GgmlType for BlockQ8_0 { let mut amax = 0f32; let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; for &x in xs.iter() { - amax = amax.max(x) + amax = amax.max(x.abs()) } let d = amax / ((1 << 7) - 1) as f32; let id = if d != 0f32 { 1. / d } else { 0. }; @@ -644,73 +694,36 @@ impl GgmlType for BlockQ8_1 { } } -const BLCK0: usize = 16; -const BLCK1: usize = 16; - -// This implementation is in-line with the ggml one and keeps the same variable names. // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 -pub fn forward_mul_mat(src0: &[T], src1: &[f32], dst: &mut [f32]) -> Result<()> { - // TODO: Use the proper sizes here. - let (ne00, ne01, ne02, ne03) = (1, 1, 1, 1); - let (ne10, ne11, ne12, ne13) = (1, 1, 1, 1); - // The strides are in bytes in ggml, however we use the number of elements in candle. - let (_, nb1, nb2, nb3) = (1, 1, 1, 1); - let (_, nb01, nb02, nb03) = (1, 1, 1, 1); - let (_, nb11, nb12, nb13) = (1, 1, 1, 1); - - let nr0 = ne01; // src0 rows - let nr1 = ne11 * ne12 * ne13; - - // TODO: Either add multi-threading or remove these bits. - let ir010 = 0; - let ir011 = nr0; - let ir110 = 0; - let ir111 = nr1; - let r2 = ne12 / ne02; - let r3 = ne13 / ne03; +pub fn matmul( + mkn: (usize, usize, usize), + lhs: &[f32], + rhs_t: &[T], + dst: &mut [f32], +) -> Result<()> { + let (m, k, n) = mkn; + if m * k != lhs.len() { + crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); + } + let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE; + let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE; + // TODO: Do not make this copy if the DotType is f32. // TODO: Pre-allocate this. - let wdata = &mut []; - if ne10 % T::BLCK_SIZE != 0 { - crate::bail!( - "forward_mul_mat: ne10 {ne10} is not divisible by block size {}", - T::BLCK_SIZE - ) + let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; + for row_idx in 0..m { + let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; + T::VecDotType::from_float(lhs, lhs_b)? } - let row_size = ne10 / T::BLCK_SIZE; - for i13 in 0..ne13 { - for i12 in 0..ne12 { - for i11 in 0..ne11 { - let wdata_idx = i11 + i12 * ne11 + i13 * ne11 * ne12; - let wdata = &mut wdata[wdata_idx..wdata_idx + row_size]; - let src1 = &src1[i13 * nb13 + i12 * nb12 + i11 * nb11..]; - T::VecDotType::from_float(src1, wdata)? - } - } - } - for iir1 in (ir110..ir111).step_by(BLCK1) { - for iir0 in (ir010..ir011).step_by(BLCK0) { - for ir1 in iir1..usize::min(iir1 + BLCK1, ir111) { - let i13 = ir1 / (ne12 * ne11); - let i12 = (ir1 - i13 * ne12 * ne11) / ne11; - let i11 = ir1 - i13 * ne12 * ne11 - i12 * ne11; + let lhs_b = lhs_b.as_slice(); - let i03 = i13 / r3; - let i02 = i12 / r2; - - let i1 = i11; - let i2 = i12; - let i3 = i13; - - let src0_row = &src0[i02 * nb02 + i03 * nb03..]; - let src1_col = &wdata[(i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size..]; - let dst_col = &mut dst[i1 * nb1 + i2 * nb2 + i3 * nb3..]; - for ir0 in iir0..usize::min(iir0 + BLCK0, ir011) { - let src0_row = &src0_row[ir0 * nb01..]; - let v = T::vec_dot(ne00, src0_row, src1_col)?; - dst_col[ir0 - iir0] += v - } - } + for row_idx in 0..m { + let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; + for (col_idx, dst) in dst_row.iter_mut().enumerate() { + let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + *dst = T::vec_dot(k, rhs_col, lhs_row)?; } } Ok(()) diff --git a/candle-core/tests/ggml_tests.rs b/candle-core/tests/ggml_tests.rs new file mode 100644 index 00000000..d976ad99 --- /dev/null +++ b/candle-core/tests/ggml_tests.rs @@ -0,0 +1,33 @@ +use candle_core::{ggml, Device, Result, Tensor}; +use ggml::GgmlType; + +#[test] +fn ggml_matmul() -> Result<()> { + let cpu = &Device::Cpu; + let (m, k, n) = (3, 64, 4); + let lhs = (0..(m * k)).map(|v| v as f32).collect::>(); + let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?; + let mut dst = vec![42.; 3 * 4]; + let mut rhs_t = vec![ggml::BlockQ4_0::zeros(); 8]; + let rhs = (0..(k * n)).map(|v| v as f32).collect::>(); + let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?; + ggml::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; + ggml::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; + assert_eq!( + dst, + &[ + 85120.43, 214561.61, 345454.9, 474748.1, 213474.94, 604465.25, 1000686.4, 1388317.3, + 341875.88, 994283.0, 1655708.8, 2301518.3 + ] + ); + let mm = tensor_lhs.matmul(&tensor_rhs)?; + assert_eq!( + mm.to_vec2::()?, + &[ + [85344.0, 214368.0, 343392.0, 472416.0], + [214368.0, 605536.0, 996704.0, 1387872.0], + [343392.0, 996704.0, 1650016.0, 2303328.0] + ] + ); + Ok(()) +}