mirror of https://github.com/huggingface/candle.git

Handle Q5_0 and Q5_1 quants in cuda.
@@ -16,6 +16,7 @@ pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
 pub const NWARPS_Q4_0_AMPERE: usize = 4;
 pub const GGML_CUDA_MMV_X: usize = 32;
 pub const GGML_CUDA_MMV_Y: usize = 1;
+pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
 
 fn dequantize(
     data: &CudaSlice<u8>,
@@ -25,28 +26,46 @@ fn dequantize(
 ) -> Result<CudaStorage> {
     use cudarc::driver::LaunchAsync;
 
-    let (kernel_name, is_k, block_dim) = match dtype {
-        GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32),
-        GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32),
-        GgmlDType::Q5_0 => ("dequantize_block_q5_0", false, 32),
-        GgmlDType::Q5_1 => ("dequantize_block_q5_1", false, 32),
-        GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32),
-        GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64),
-        GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64),
-        GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32),
-        GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64),
-        GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64),
-        GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32),
+    let nb = (elem_count + 255) / 256;
+    let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
+        GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
+        GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
+        GgmlDType::Q5_0 => {
+            let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
+                / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
+            (
+                "dequantize_block_q5_0",
+                false,
+                CUDA_DEQUANTIZE_BLOCK_SIZE,
+                nb,
+            )
+        }
+        GgmlDType::Q5_1 => {
+            let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
+                / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
+            (
+                "dequantize_block_q5_1",
+                false,
+                CUDA_DEQUANTIZE_BLOCK_SIZE,
+                nb,
+            )
+        }
+        GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
+        GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
+        GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
+        GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
+        GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
+        GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
+        GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
         _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
     let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
     let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
-    let nb = (elem_count + 255) / 256;
     // See e.g.
     // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
     let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (nb as u32, 1, 1),
-        block_dim: (block_dim, 1, 1),
+        grid_dim: (num_blocks as u32, 1, 1),
+        block_dim: (block_dim as u32, 1, 1),
         shared_mem_bytes: 0,
     };
 
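Note on the Q5_0/Q5_1 arms above: the generic dequantize_block kernel has each thread write two output values, so the grid has to cover elem_count at a rate of 2 * CUDA_DEQUANTIZE_BLOCK_SIZE elements per block. A minimal standalone sketch of that arithmetic (the grid_size helper is made up for illustration, not part of candle):

    const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;

    // Ceil-divide the element count by the elements one block covers:
    // 256 threads per block, 2 dequantized values per thread.
    fn grid_size(elem_count: usize) -> usize {
        (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE)
    }

    fn main() {
        assert_eq!(grid_size(4096), 8); // exact multiple: 8 blocks of 256 threads
        assert_eq!(grid_size(4100), 9); // a partial tail still gets its own block
    }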
@@ -54,7 +73,10 @@ fn dequantize(
         let params = (data, &dst);
         unsafe { func.launch(cfg, params) }.w()?;
     } else {
-        let nb32 = elem_count / 32;
+        let nb32 = match dtype {
+            GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
+            _ => elem_count / 32,
+        };
         let params = (data, &dst, nb32 as i32);
         unsafe { func.launch(cfg, params) }.w()?;
     }

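Note on the hunk above: the third kernel argument keeps its old meaning, the number of 32-wide blocks, for most non-k quants, while the new q5_0/q5_1 entry points forward it unchanged as the element count k that the templated dequantize_block bounds-checks against. A rough sketch of that distinction, using a local stand-in enum rather than candle's GgmlDType:

    #[allow(non_camel_case_types)]
    enum DType { Q4_0, Q5_0, Q5_1 }

    // Hypothetical helper mirroring the match in the diff above.
    fn third_launch_arg(dtype: &DType, elem_count: usize) -> i32 {
        match dtype {
            // q5_0/q5_1 reuse the generic dequantize_block template, which
            // bounds-checks each thread against the total element count.
            DType::Q5_0 | DType::Q5_1 => elem_count as i32,
            // The other non-k kernels iterate over 32-element blocks.
            _ => (elem_count / 32) as i32,
        }
    }

    fn main() {
        assert_eq!(third_launch_arg(&DType::Q4_0, 4096), 128);
        assert_eq!(third_launch_arg(&DType::Q5_0, 4096), 4096);
        assert_eq!(third_launch_arg(&DType::Q5_1, 4096), 4096);
    }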
@@ -231,10 +231,6 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
 }
 
 fn quantize_q5_0(device: &Device) -> Result<()> {
-    // TODO Enable this later when we enable cuda.
-    if device.is_cuda() {
-        return Ok(());
-    }
     let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
     let src = Tensor::from_slice(&src, (32 * 4,), device)?;
     let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
@@ -261,10 +257,6 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
 }
 
 fn quantize_q5_1(device: &Device) -> Result<()> {
-    // TODO Enable this later when we enable cuda.
-    if device.is_cuda() {
-        return Ok(());
-    }
     let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
     let src = Tensor::from_slice(&src, (32 * 4,), device)?;
     let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;

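With the early returns removed, these roundtrip tests now run on CUDA devices as well. For context, the overall shape of such a test, sketched against the candle_core API already used above (the dequantize call and the tolerance here are illustrative, not copied from the test file):

    use candle_core::{quantized::{self, GgmlDType}, Device, Result, Tensor};

    fn roundtrip_q5_0(device: &Device) -> Result<()> {
        // Same ramp input as the tests above: 128 consecutive f32 values.
        let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
        let src = Tensor::from_slice(&src, (32 * 4,), device)?;
        // Quantize to Q5_0, then dequantize back on the same device.
        let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
        let dst = quant.dequantize(device)?.to_vec1::<f32>()?;
        let src_vals = src.to_vec1::<f32>()?;
        // Q5_0 is lossy; only check that each value stays close to its input.
        for (d, s) in dst.iter().zip(src_vals.iter()) {
            assert!((d - s).abs() < 1.0, "roundtrip error too large: {d} vs {s}");
        }
        Ok(())
    }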
@@ -575,7 +575,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
+static __device__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
     const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
 
     if (i >= k) {
@@ -595,12 +595,6 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     y[iybs + iqs + y_offset] = v.y;
 }
 
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
-    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
 extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
 
     const int i = blockIdx.x;
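The removed dequantize_block_cuda host wrapper is the counterpart of the Rust changes above: its grid/block computation now lives in the Rust dequantize function, and the templated dequantize_block (now __device__) is only reached through the extern "C" per-dtype kernels. A condensed sketch of the equivalent host-side setup, using only the cudarc LaunchConfig fields that appear in the diff (the helper name is made up for illustration):

    use cudarc::driver::LaunchConfig;

    const CUDA_DEQUANTIZE_BLOCK_SIZE: u32 = 256;

    // Same math as the removed C++ wrapper:
    //   num_blocks = (k + 2*BLOCK - 1) / (2*BLOCK), launched as <<<num_blocks, BLOCK>>>.
    fn q5_dequantize_cfg(elem_count: u32) -> LaunchConfig {
        let num_blocks =
            (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
        LaunchConfig {
            grid_dim: (num_blocks, 1, 1),
            block_dim: (CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1),
            shared_mem_bytes: 0,
        }
    }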
@@ -910,6 +904,14 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f
 #endif
 }
 
+extern "C" __global__ void dequantize_block_q5_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
+    return dequantize_block<QK5_0, QR5_0, dequantize_q5_0>(vx, yy, nb32);
+}
+
+extern "C" __global__ void dequantize_block_q5_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
+    return dequantize_block<QK5_1, QR5_1, dequantize_q5_1>(vx, yy, nb32);
+}
+
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
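For reference, QK5_0 and QK5_1 are both 32 weights per block: q5_0 stores an f16 scale, 4 bytes holding each weight's fifth bit, and 16 bytes of packed low nibbles, while q5_1 additionally stores an f16 minimum. A quick size sanity check (the byte counts below are the standard ggml block layout, stated from general knowledge rather than read from this diff):

    const QK5_0: usize = 32;
    const QK5_1: usize = 32;

    fn main() {
        // q5_0: f16 scale (2 B) + 4 B of fifth bits + QK5_0/2 B of low nibbles.
        let q5_0_bytes = 2 + 4 + QK5_0 / 2;
        assert_eq!(q5_0_bytes, 22);
        // q5_1: f16 scale + f16 min (4 B) + 4 B of fifth bits + QK5_1/2 B of nibbles.
        let q5_1_bytes = 2 + 2 + 4 + QK5_1 / 2;
        assert_eq!(q5_1_bytes, 24);
        println!("q5_0: {q5_0_bytes} bytes/block, q5_1: {q5_1_bytes} bytes/block");
    }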