Fix the block size for some cuda kernels. (#1767)

This commit is contained in:
Laurent Mazare
2024-02-27 14:08:33 +01:00
committed by GitHub
parent 32544a2ad6
commit 6400e1b0a0
2 changed files with 15 additions and 45 deletions

View File

@ -25,26 +25,28 @@ fn dequantize(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let (kernel_name, is_k) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false),
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false),
GgmlDType::Q5_0 => ("dequantize_block_q5_0", false),
GgmlDType::Q5_1 => ("dequantize_block_q5_1", false),
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false),
GgmlDType::Q2K => ("dequantize_block_q2_K", true),
GgmlDType::Q3K => ("dequantize_block_q3_K", true),
GgmlDType::Q4K => ("dequantize_block_q4_K", true),
GgmlDType::Q5K => ("dequantize_block_q5_K", true),
GgmlDType::Q6K => ("dequantize_block_q6_K", true),
GgmlDType::Q8K => ("dequantize_block_q8_K", true),
let (kernel_name, is_k, block_dim) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32),
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32),
GgmlDType::Q5_0 => ("dequantize_block_q5_0", false, 32),
GgmlDType::Q5_1 => ("dequantize_block_q5_1", false, 32),
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32),
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64),
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64),
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32),
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64),
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64),
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
let nb = (elem_count + 255) / 256;
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nb as u32, 1, 1),
block_dim: (32, 1, 1),
block_dim: (block_dim, 1, 1),
shared_mem_bytes: 0,
};