From 8de0ce6cba823c53344ebdee028a13f8d564dee0 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 18 Apr 2024 08:36:43 +0200
Subject: [PATCH] Add more QMMV cuda kernels. (#2077)

* Add more QMMV cuda kernels.

* Enable the new kernels.

* Adapt the testing.
---
 candle-core/src/quantized/cuda.rs    |  18 +-
 candle-core/tests/quantized_tests.rs |  22 +-
 candle-kernels/src/quantized.cu      | 324 +++++++++++++++++++++++++++
 3 files changed, 349 insertions(+), 15 deletions(-)

diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index d6a61682..5481ca3c 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -178,8 +178,8 @@ fn mul_mat_vec_via_q8_1(
     if y.len() != ncols * b_size {
         crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
     }
-    if b_size == 0 || b_size > 4 {
-        crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
+    if b_size == 0 || b_size > 8 {
+        crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
     }
     // Start by quantizing y
     let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
@@ -204,14 +204,16 @@ fn mul_mat_vec_via_q8_1(
     let kernel_name = format!("{kernel_name}{b_size}");
     let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
-    let nblocks = if b_size == 1 {
-        nrows as u32
-    } else {
-        (nrows as u32 + 1) / 2
+    // https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
+    let (nblocks, nwarps) = match b_size {
+        1 => (nrows as u32, 4),
+        2..=4 => ((nrows as u32 + 1) / 2, 4),
+        5..=8 => ((nrows as u32 + 1) / 2, 2),
+        _ => crate::bail!("unexpected bsize {b_size}"),
     };
     let cfg = cudarc::driver::LaunchConfig {
         grid_dim: (nblocks, 1, 1),
-        block_dim: (WARP_SIZE as u32, 4, 1),
+        block_dim: (WARP_SIZE as u32, nwarps, 1),
         shared_mem_bytes: 0,
     };
 
@@ -398,7 +400,7 @@ impl QCudaStorage {
         let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
             1
         } else {
-            4
+            8
         };
         let use_vec_kernel = match layout.shape().dims() {
             [b, m, _k] => b * m <= max_bm,
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index 157f2f8d..b2a64ac9 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -193,17 +193,25 @@ fn qmm_batch(dev: &Device) -> Result<()> {
     let mm3 = rhs.forward(&lhs3)?;
     assert_eq!(mm3.shape().dims(), [6, 6]);
     let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
-    if dev.is_cuda() {
-        assert!(diff3 < 1e-4)
-    } else {
-        assert_eq!(diff3, 0.0)
-    };
+    assert_eq!(diff3, 0.0);
     let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff3, 0.0);
+    let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
+    let mm4 = rhs.forward(&lhs4)?;
+    assert_eq!(mm4.shape().dims(), [12, 6]);
+    let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
     if dev.is_cuda() {
-        assert!(diff3 < 1e-4)
+        // We use a different kernel for sizes from 1 to 8 on cuda which explains
+        // the difference here.
+        assert!(0. < diff4 && diff4 < 1e-4)
     } else {
-        assert_eq!(diff3, 0.0)
+        assert_eq!(diff4, 0.0)
     };
+    let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff4, 0.0);
     Ok(())
 }
 
diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu
index 7e3e7b4c..c5bc4563 100644
--- a/candle-kernels/src/quantized.cu
+++ b/candle-kernels/src/quantized.cu
@@ -2972,6 +2972,330 @@ extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda4(
     (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 }
 
+// batch size = 5
+extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda5(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<5, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda5(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<5, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda5(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<5, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda5(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<5, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda5(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<5, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda5(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<5, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda5(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<5, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda5(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<5, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda5(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<5, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda5(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<5, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+// batch size = 6
+extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda6(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<6, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda6(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<6, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda6(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<6, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda6(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<6, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda6(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<6, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda6(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<6, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda6(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<6, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda6(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<6, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda6(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<6, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda6(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<6, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+// batch size = 7
+extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +// batch size = 8 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + 
mul_mat_vec_q<8, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { const int ix = blockDim.x*blockIdx.x + threadIdx.x;
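
For reference, the grid/block selection in `mul_mat_vec_via_q8_1` follows the llama.cpp mmvq heuristic linked in the diff: one CUDA block per output row for a batch of 1, one block per two rows otherwise, with the warps per block halved once the batch exceeds 4. Below is a minimal standalone sketch of that rule; the helper and struct names (`launch_dims_for`, `LaunchDims`) are illustrative only, since the patch inlines this logic rather than exposing an API:

    // Mirrors the (nblocks, nwarps) match added in candle-core/src/quantized/cuda.rs.
    const WARP_SIZE: u32 = 32;

    struct LaunchDims {
        grid_dim: (u32, u32, u32),
        block_dim: (u32, u32, u32),
    }

    fn launch_dims_for(b_size: usize, nrows: usize) -> Option<LaunchDims> {
        let (nblocks, nwarps) = match b_size {
            1 => (nrows as u32, 4),               // one row per block
            2..=4 => ((nrows as u32 + 1) / 2, 4), // two rows per block
            5..=8 => ((nrows as u32 + 1) / 2, 2), // same grid, half the warps
            _ => return None,                     // kernels only cover batch sizes 1..=8
        };
        Some(LaunchDims {
            grid_dim: (nblocks, 1, 1),
            block_dim: (WARP_SIZE, nwarps, 1),
        })
    }

The kernel variant itself is chosen by suffixing the batch size to the base name (`format!("{kernel_name}{b_size}")` in the diff), which is why quantized.cu gains `_cuda5` through `_cuda8` entry points for every supported quantization type.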