diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index 6e078a6e..7929fba6 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -11,15 +11,15 @@ use candle_core::quantized::{QMatMul, QTensor};
 
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
-    let q = Tensor::randn(0f32, 1.0, (72, 32), &device)?;
+    let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
     let q_cpu = q.to_device(&Device::Cpu)?;
-    let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q4_0)?;
+    let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
     let q = QMatMul::from_qtensor(q)?;
-    let x = Tensor::randn(0f32, 1.0, (5, 32), &device)?;
+    let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
     let res_q_cuda = q.forward(&x)?;
     println!("{res_q_cuda}");
 
-    let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q4_0)?;
+    let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
     let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
     let q_cpu = QMatMul::from_qtensor(q_cpu)?;
     let x_cpu = x.to_device(&Device::Cpu)?;
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index a2fc6655..e44d8093 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -36,7 +36,8 @@ fn dequantize(
         GgmlDType::Q4K => ("dequantize_block_q4_K", true),
         GgmlDType::Q5K => ("dequantize_block_q5_K", true),
         GgmlDType::Q6K => ("dequantize_block_q6_K", true),
-        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
+        GgmlDType::Q8K => ("dequantize_block_q8_K", true),
+        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
     let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
     let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
@@ -115,19 +116,20 @@ impl QCudaStorage {
     }
 
     pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
-        let fast_kernel = match self.dtype {
+        let fast_kernel = matches!(
+            self.dtype,
             GgmlDType::Q4_0
-            | GgmlDType::Q4_1
-            | GgmlDType::Q5_0
-            | GgmlDType::Q5_1
-            | GgmlDType::Q8_0
-            | GgmlDType::Q2K
-            | GgmlDType::Q3K
-            | GgmlDType::Q4K
-            | GgmlDType::Q5K
-            | GgmlDType::Q6K => true,
-            _ => false,
-        };
+                | GgmlDType::Q4_1
+                | GgmlDType::Q5_0
+                | GgmlDType::Q5_1
+                | GgmlDType::Q8_0
+                | GgmlDType::Q2K
+                | GgmlDType::Q3K
+                | GgmlDType::Q4K
+                | GgmlDType::Q5K
+                | GgmlDType::Q6K
+                | GgmlDType::Q8K
+        );
         if fast_kernel {
             return dequantize(&self.data, self.dtype, elem_count, self.device());
         }
@@ -229,11 +231,7 @@ impl QCudaStorage {
         storage: &CudaStorage,
         layout: &crate::Layout,
     ) -> Result<(CudaStorage, crate::Shape)> {
-        let dmmv = match layout.shape().dims() {
-            [1, 1, _] | [1, _] => true,
-            _ => false,
-        };
-        if dmmv {
+        if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
             self.dequantize_matmul_vec(self_shape, storage, layout)
         } else {
             self.dequantize_matmul(self_shape, storage, layout)
diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu
index bf81487a..4d32f6fa 100644
--- a/candle-kernels/src/quantized.cu
+++ b/candle-kernels/src/quantized.cu
@@ -224,6 +224,14 @@ typedef struct {
 } block_q6_K;
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
 
+// In llama.cpp this is only used for intermediate quantization and dot products
+typedef struct {
+    float   d;              // delta
+    int8_t  qs[QK_K];       // quants
+    int16_t bsums[QK_K/16]; // sum of quants in groups of 16
+} block_q8_K;
+static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
+
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
@@ -875,6 +883,33 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
 #endif
 }
 
+extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) {
+    const block_q8_K * x = (const block_q8_K *) vx;
+
+    const int i = blockIdx.x;
+
+#if QK_K == 256
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int il = tid/8;
+    const int ir = tid%8;
+    const int n = 8;
+
+    float * y = yy + i*QK_K + 64*il + n*ir;
+
+    const int8_t * q = x[i].qs + 64*il + n*ir;
+
+    for (int l = 0; l < n; ++l) {
+        y[l] = q[l] * x[i].d;
+    }
+#else
+    const int tid = threadIdx.x;
+    const int8_t * q = x[i].qs;
+    float * y = yy + i*QK_K;
+    y[tid] = x[i].d * q[tid];
+#endif
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
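
Note on the new format: with QK_K = 256, a block_q8_K packs one f32 delta, 256 int8 quants, and 16 i16 group sums, i.e. 4 + 256 + 32 = 292 bytes, which is exactly what the static_assert above checks. It also means Q8K only applies when the inner dimension is a multiple of 256, hence the 32 -> 256 bump in cuda_basics.rs. Below is a minimal round-trip sketch for exercising the new path, assuming only the candle API already used in the example; the helper name q8k_roundtrip_error is ours, not part of the crate:

use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{Device, Result, Tensor};

// Quantize to Q8K, dequantize on the same device, and report the mean
// absolute round-trip error. On a CUDA device the dequantize call reaches
// the new dequantize_block_q8_K kernel; on Device::Cpu it produces the
// reference value to compare against, mirroring the example above.
fn q8k_roundtrip_error(device: &Device) -> Result<f32> {
    // The inner dimension must be a multiple of QK_K (256) for K-quants.
    let src = Tensor::randn(0f32, 1.0, (72, 256), device)?;
    let q = QTensor::quantize(&src, GgmlDType::Q8K)?;
    let deq = q.dequantize(device)?;
    (src - deq)?.abs()?.mean_all()?.to_scalar::<f32>()
}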