From f7d5bf5b97071c5bb299084559992e4681fcf277 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 15 Apr 2024 08:32:47 +0200 Subject: [PATCH] Faster kernels for quantized matmul on cuda (#2060) * Hook the quantized matmul cuda kernels. * Add a (currently broken) test. * Kernel fixes. * Fix by transposing the rhs matrix. * Add the q4-1 kernels. * Proper block sizes. * More details in the tests. --- candle-core/src/quantized/cuda.rs | 143 ++++++++++++++++++++++++++++-- candle-kernels/src/quantized.cu | 129 ++++++++++++++++++++++++--- 2 files changed, 255 insertions(+), 17 deletions(-) diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 07f8c13e..487431f6 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -40,6 +40,7 @@ fn quantize_q8_1( src: &CudaView, dst: &mut CudaSlice, elem_count: usize, + ky: usize, dev: &CudaDevice, ) -> Result<()> { use cudarc::driver::LaunchAsync; @@ -49,7 +50,7 @@ fn quantize_q8_1( let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE); let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?; let cfg = cudarc::driver::LaunchConfig { - grid_dim: (num_blocks as u32, 1, 1), + grid_dim: (num_blocks as u32, ky as u32, 1), block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1), shared_mem_bytes: 0, }; @@ -180,7 +181,7 @@ fn mul_mat_vec_via_q8_1( let ncols_padded = pad(ncols, MATRIX_ROW_PADDING); let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size(); let mut y_q8_1 = unsafe { dev.alloc::(y_size_in_bytes).w()? 
}; - quantize_q8_1(y, &mut y_q8_1, ncols, dev)?; + quantize_q8_1(y, &mut y_q8_1, ncols, 1, dev)?; let kernel_name = match dtype { GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda", @@ -216,6 +217,75 @@ fn mul_mat_vec_via_q8_1( Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) } +fn mul_mat_via_q8_1( + data: &CudaSlice, + y: &CudaView, + dtype: GgmlDType, + x_rows: usize, + x_cols: usize, + y_rows: usize, + y_cols: usize, + dev: &CudaDevice, +) -> Result { + use cudarc::driver::LaunchAsync; + + let data_elems = data.len() / dtype.type_size() * dtype.block_size(); + if data_elems < x_rows * x_cols { + crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems) + } + if y.len() != y_rows * y_cols { + crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len()) + } + if x_cols != y_rows { + crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}") + } + let k = x_cols; + // Start by quantizing y: quantize_q8_1 is launched over y_cols rows of k_padded values, so the buffer holds y_cols rows. + let k_padded = pad(k, MATRIX_ROW_PADDING); + let y_size_in_bytes = + k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size(); + let mut y_q8_1 = unsafe { dev.alloc::(y_size_in_bytes).w()? }; + quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?; + + let (kernel_name, mmq_x, mmq_y) = match dtype { + GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128), + GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128), + GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64), + GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64), + GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64), + GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128), + GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128), + GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128), + GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128), + GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64), + _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"), + }; + let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; + let dst = unsafe { dev.alloc::(x_rows * y_cols).w()? 
}; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: ( + ceil_div(x_rows, mmq_y) as u32, + ceil_div(y_cols, mmq_x) as u32, + 1, + ), + block_dim: (WARP_SIZE as u32, 4, 1), + shared_mem_bytes: 0, + }; + + let params = ( + /* vx */ data, + /* vy */ &y_q8_1, + /* dst */ &dst, + /* ncols_x */ x_cols as i32, + /* nrows_x */ x_rows as i32, + /* ncols_y */ y_cols as i32, + /* nrows_y */ k_padded as i32, + /* nrows_dst */ x_rows as i32, + ); + unsafe { func.launch(cfg, params) }.w()?; + Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) +} + impl QCudaStorage { pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result { let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size(); @@ -373,9 +443,30 @@ impl QCudaStorage { crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape()) } - let data_f32 = self.dequantize(n * k)?; - let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?; - let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?; + let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) { + let data_f32 = self.dequantize(n * k)?; + let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?; + storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)? + } else { + let storage = storage.as_cuda_slice::()?; + let storage = match layout.contiguous_offsets() { + Some((o1, o2)) => storage.slice(o1..o2), + None => Err(crate::Error::RequiresContiguous { + op: "quantized-matmul", + } + .bt())?, + }; + mul_mat_via_q8_1( + &self.data, + &storage, + self.dtype, + /* x_rows */ n, + /* x_cols */ k, + /* y_rows */ k, + /* y_cols */ b * m, + self.device(), + )? + }; let mut out_shape = layout.shape().dims().to_vec(); out_shape.pop(); out_shape.push(n); @@ -412,7 +503,7 @@ mod test { let mut y_q8_1 = unsafe { dev.alloc::(y_size_in_bytes).w()? 
}; let vs: Vec = (0..el).map(|v| v as f32).collect(); let y = dev.htod_sync_copy(&vs).w()?; - quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?; + quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?; Ok(()) } @@ -453,4 +544,44 @@ mod test { assert_eq!(vs[0], 5561851.0); Ok(()) } + + #[test] + fn cuda_mm_q8_1() -> Result<()> { + let dev = CudaDevice::new(0)?; + let ncols = 256; + let vs: Vec = (0..ncols * 4).map(|v| v as f32 / 4.).collect(); + let y = dev.htod_sync_copy(&vs).w()?; + let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?; + xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?; + let cuda_storage = mul_mat_via_q8_1( + &xs.data, + &y.slice(..), + /* dtype */ GgmlDType::Q4_0, + /* x_rows */ 4, + /* x_cols */ ncols, + /* y_rows */ ncols, + /* y_cols */ 4, + &dev, + )?; + let vs = cuda_storage.as_cuda_slice::()?; + let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap(); + + /* + x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256) + x @ x.t() / 16 + tensor([[ 347480.0000, 869720.0000, 1391960.0000, 1914200.0000], + [ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000], + [ 1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000], + [ 1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]]) + */ + assert_eq!(vs.len(), 16); + assert_eq!(vs[0], 347604.0); + assert_eq!(vs[1], 888153.06); + assert_eq!(vs[4], 869780.7); + assert_eq!(vs[5], 2483145.0); + assert_eq!(vs[11], 9407368.0); + assert_eq!(vs[14], 9470856.0); + assert_eq!(vs[15], 13138824.0); + Ok(()) + } } diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu index c3ce9568..fa38f325 100644 --- a/candle-kernels/src/quantized.cu +++ b/candle-kernels/src/quantized.cu @@ -71,8 +71,6 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * } -#define CUDA_USE_TENSOR_CORES - #define WARP_SIZE 32 #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. 
for which __hmax and __hmax2 are known to work (may be higher than needed) @@ -103,6 +101,25 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * #define MMQ_Y_Q4_0_PASCAL 64 #define NWARPS_Q4_0_PASCAL 8 +#define MMQ_X_Q4_1_RDNA2 64 +#define MMQ_Y_Q4_1_RDNA2 128 +#define NWARPS_Q4_1_RDNA2 8 +#define MMQ_X_Q4_1_RDNA1 64 +#define MMQ_Y_Q4_1_RDNA1 64 +#define NWARPS_Q4_1_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q4_1_AMPERE 4 +#define MMQ_Y_Q4_1_AMPERE 32 +#define NWARPS_Q4_1_AMPERE 4 +#else +#define MMQ_X_Q4_1_AMPERE 64 +#define MMQ_Y_Q4_1_AMPERE 128 +#define NWARPS_Q4_1_AMPERE 4 +#endif +#define MMQ_X_Q4_1_PASCAL 64 +#define MMQ_Y_Q4_1_PASCAL 64 +#define NWARPS_Q4_1_PASCAL 8 + #define MMQ_X_Q5_0_RDNA2 64 #define MMQ_Y_Q5_0_RDNA2 128 #define NWARPS_Q5_0_RDNA2 8 @@ -558,6 +575,52 @@ template static __device__ __forceinlin } } +template static __device__ __forceinline__ void load_tiles_q4_1( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_1; + const int kqsx = k % QI4_1; + + const block_q4_1 * bx0 = (const block_q4_1 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { + int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row; + + if 
(need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; + } +} + + template static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { (void)x_qh; (void)x_sc; @@ -568,6 +631,16 @@ template static __device__ __forceinline__ void allocate_tiles_q4_0( *x_dm = (half2 *) tile_x_d; } +template static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1]; + + *x_ql = tile_x_qs; + *x_dm = tile_x_dm; +} + static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; @@ -3493,6 +3566,26 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); } +static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + } + + return vec_dot_q4_1_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + extern "C" __global__ void mul_mat_q4_0( @@ -3503,10 +3596,24 
@@ extern "C" __global__ void const int nwarps = NWARPS_Q4_0_AMPERE; mul_mat_q, - load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> + load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); } +extern "C" __global__ void + mul_mat_q4_1( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q4_1_AMPERE; + const int mmq_y = MMQ_Y_Q4_1_AMPERE; + const int nwarps = NWARPS_Q4_1_AMPERE; + + mul_mat_q, + load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + + extern "C" __global__ void mul_mat_q5_0( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, @@ -3516,7 +3623,7 @@ extern "C" __global__ void const int nwarps = NWARPS_Q5_0_AMPERE; mul_mat_q, - load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> + load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); } @@ -3529,7 +3636,7 @@ mul_mat_q5_1( const int nwarps = NWARPS_Q5_1_AMPERE; mul_mat_q, - load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> + load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); } @@ -3542,7 +3649,7 @@ extern "C" __global__ void const int nwarps = NWARPS_Q8_0_AMPERE; mul_mat_q, - load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> + load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); } @@ -3554,7 +3661,7 @@ mul_mat_q2_K( const int mmq_y = MMQ_Y_Q2_K_AMPERE; const int nwarps = NWARPS_Q2_K_AMPERE; mul_mat_q, - load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> + load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, 
vec_dot_q2_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); } @@ -3566,7 +3673,7 @@ extern "C" __global__ void const int mmq_y = MMQ_Y_Q3_K_AMPERE; const int nwarps = NWARPS_Q3_K_AMPERE; mul_mat_q, - load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> + load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); } @@ -3578,7 +3685,7 @@ extern "C" __global__ void const int mmq_y = MMQ_Y_Q4_K_AMPERE; const int nwarps = NWARPS_Q4_K_AMPERE; mul_mat_q, - load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> + load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); } @@ -3590,7 +3697,7 @@ mul_mat_q5_K( const int mmq_y = MMQ_Y_Q5_K_AMPERE; const int nwarps = NWARPS_Q5_K_AMPERE; mul_mat_q, - load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> + load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); } @@ -3602,6 +3709,6 @@ extern "C" __global__ void const int mmq_y = MMQ_Y_Q6_K_AMPERE; const int nwarps = NWARPS_Q6_K_AMPERE; mul_mat_q, - load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> + load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); }