From 6400e1b0a08d594e1448d522a41bddc98c584313 Mon Sep 17 00:00:00 2001 From: Laurent Mazare <laurent.mazare@gmail.com> Date: Tue, 27 Feb 2024 14:08:33 +0100 Subject: [PATCH] Fix the block size for some cuda kernels. (#1767) --- candle-core/src/quantized/cuda.rs | 28 +++++++++++++----------- candle-core/tests/quantized_tests.rs | 32 ---------------------------- 2 files changed, 15 insertions(+), 45 deletions(-) diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index e44d8093..84af483d 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -25,26 +25,28 @@ fn dequantize( ) -> Result<CudaStorage> { use cudarc::driver::LaunchAsync; - let (kernel_name, is_k) = match dtype { - GgmlDType::Q4_0 => ("dequantize_block_q4_0", false), - GgmlDType::Q4_1 => ("dequantize_block_q4_1", false), - GgmlDType::Q5_0 => ("dequantize_block_q5_0", false), - GgmlDType::Q5_1 => ("dequantize_block_q5_1", false), - GgmlDType::Q8_0 => ("dequantize_block_q8_0", false), - GgmlDType::Q2K => ("dequantize_block_q2_K", true), - GgmlDType::Q3K => ("dequantize_block_q3_K", true), - GgmlDType::Q4K => ("dequantize_block_q4_K", true), - GgmlDType::Q5K => ("dequantize_block_q5_K", true), - GgmlDType::Q6K => ("dequantize_block_q6_K", true), - GgmlDType::Q8K => ("dequantize_block_q8_K", true), + let (kernel_name, is_k, block_dim) = match dtype { + GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32), + GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32), + GgmlDType::Q5_0 => ("dequantize_block_q5_0", false, 32), + GgmlDType::Q5_1 => ("dequantize_block_q5_1", false, 32), + GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32), + GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64), + GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64), + GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32), + GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64), + GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64), + GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32), _ => 
crate::bail!("unsupported dtype for dequantize {dtype:?}"), }; let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; let dst = dev.alloc_zeros::<f32>(elem_count).w()?; let nb = (elem_count + 255) / 256; + // See e.g. + // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270 let cfg = cudarc::driver::LaunchConfig { grid_dim: (nb as u32, 1, 1), - block_dim: (32, 1, 1), + block_dim: (block_dim, 1, 1), shared_mem_bytes: 0, }; diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index a7811ca5..5f7e4825 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -178,10 +178,6 @@ test_device!( ); fn quantize_q4_0(device: &Device) -> Result<()> { - // TODO Enable this later when we enable cuda. - if device.is_cuda() { - return Ok(()); - } let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = Tensor::from_slice(&src, (32 * 4,), device)?; @@ -209,10 +205,6 @@ fn quantize_q4_0(device: &Device) -> Result<()> { } fn quantize_q4_1(device: &Device) -> Result<()> { - // TODO Enable this later when we enable cuda. - if device.is_cuda() { - return Ok(()); - } let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = Tensor::from_slice(&src, (32 * 4,), device)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?; @@ -373,10 +365,6 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3 } fn quantize_q2k(device: &Device) -> Result<()> { - // TODO Enable this later when we enable cuda. - if device.is_cuda() { - return Ok(()); - } let dtype = GgmlDType::Q2K; let src = get_test_vector2(0.5, 1024, device)?; @@ -411,10 +399,6 @@ fn quantize_q2k(device: &Device) -> Result<()> { } fn quantize_q3k(device: &Device) -> Result<()> { - // TODO Enable this later when we enable cuda. 
- if device.is_cuda() { - return Ok(()); - } let dtype = GgmlDType::Q3K; let src = get_test_vector2(0.5, 1024, device)?; let quant = quantized::QTensor::quantize(&src, dtype)?; @@ -448,10 +432,6 @@ fn quantize_q3k(device: &Device) -> Result<()> { } fn quantize_q4k(device: &Device) -> Result<()> { - // TODO Enable this later when we enable cuda. - if device.is_cuda() { - return Ok(()); - } let dtype = GgmlDType::Q4K; let src = get_test_vector2(0.5, 1024, device)?; let quant = quantized::QTensor::quantize(&src, dtype)?; @@ -485,10 +465,6 @@ fn quantize_q4k(device: &Device) -> Result<()> { } fn quantize_q5k(device: &Device) -> Result<()> { - // TODO Enable this later when we enable cuda. - if device.is_cuda() { - return Ok(()); - } let dtype = GgmlDType::Q5K; let src = get_test_vector2(0.5, 1024, device)?; let quant = quantized::QTensor::quantize(&src, dtype)?; @@ -522,10 +498,6 @@ fn quantize_q5k(device: &Device) -> Result<()> { } fn quantize_q6k(device: &Device) -> Result<()> { - // TODO Enable this later when we enable cuda. - if device.is_cuda() { - return Ok(()); - } let dtype = GgmlDType::Q6K; let src = get_test_vector2(0.5, 1024, device)?; let quant = quantized::QTensor::quantize(&src, dtype)?; @@ -559,10 +531,6 @@ fn quantize_q6k(device: &Device) -> Result<()> { } fn quantize_q8k(device: &Device) -> Result<()> { - // TODO Enable this later when we enable cuda. - if device.is_cuda() { - return Ok(()); - } let dtype = GgmlDType::Q8K; let src = get_test_vector2(0.5, 1024, device)?; let quant = quantized::QTensor::quantize(&src, dtype)?;