Cuda kernel for dequantizing q8k. (#1760)

* Cuda kernel for dequantizing q8k.
* Clippy lints.
@@ -11,15 +11,15 @@ use candle_core::quantized::{QMatMul, QTensor};
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
-    let q = Tensor::randn(0f32, 1.0, (72, 32), &device)?;
+    let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
     let q_cpu = q.to_device(&Device::Cpu)?;
-    let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q4_0)?;
+    let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
     let q = QMatMul::from_qtensor(q)?;
-    let x = Tensor::randn(0f32, 1.0, (5, 32), &device)?;
+    let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
     let res_q_cuda = q.forward(&x)?;
     println!("{res_q_cuda}");

-    let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q4_0)?;
+    let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
     let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
     let q_cpu = QMatMul::from_qtensor(q_cpu)?;
     let x_cpu = x.to_device(&Device::Cpu)?;

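The example dimensions move from 32 to 256 because Q8K is a k-quant: it packs super-blocks of QK_K = 256 values, so the inner matmul dimension must be a multiple of 256. A minimal round-trip sketch of the same flow (assumes a CUDA build of candle; `q8k_roundtrip` is an illustrative helper, not part of the example file):

    use candle_core::quantized::{GgmlDType, QTensor};
    use candle_core::{Device, Result, Tensor};

    fn q8k_roundtrip() -> Result<()> {
        let device = Device::new_cuda(0)?;
        // The inner dimension must be a multiple of 256 (Q8K super-block size).
        let src = Tensor::randn(0f32, 1.0, (4, 256), &device)?;
        let q = QTensor::quantize(&src, GgmlDType::Q8K)?;
        // Dequantize on the GPU via the new kernel, compare with the source.
        let back = q.dequantize(&device)?;
        let diff = (&src - &back)?.abs()?.mean_all()?;
        println!("mean abs error: {diff}");
        Ok(())
    }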
@@ -36,7 +36,8 @@ fn dequantize(
         GgmlDType::Q4K => ("dequantize_block_q4_K", true),
         GgmlDType::Q5K => ("dequantize_block_q5_K", true),
         GgmlDType::Q6K => ("dequantize_block_q6_K", true),
-        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
+        GgmlDType::Q8K => ("dequantize_block_q8_K", true),
+        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
     let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
     let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
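The dequantize helper keeps a single table from GGML dtype to CUDA kernel name plus a flag marking k-quants, whose kernels are launched per 256-element super-block rather than per 32-element legacy block. Note the error message is also corrected here: this path is dequantization, not quantized matmul. A standalone sketch of the mapping (kernel names are the ones from the hunk above; the function itself is illustrative, not candle's actual dispatch code):

    use candle_core::quantized::GgmlDType;

    // Illustrative mapping only; the bool marks k-quants, which use
    // 256-element super-blocks and hence a different launch geometry.
    fn dequantize_kernel(dtype: GgmlDType) -> Option<(&'static str, bool)> {
        match dtype {
            GgmlDType::Q4K => Some(("dequantize_block_q4_K", true)),
            GgmlDType::Q5K => Some(("dequantize_block_q5_K", true)),
            GgmlDType::Q6K => Some(("dequantize_block_q6_K", true)),
            GgmlDType::Q8K => Some(("dequantize_block_q8_K", true)),
            _ => None,
        }
    }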
@@ -115,19 +116,20 @@ impl QCudaStorage {
     }

     pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
-        let fast_kernel = match self.dtype {
+        let fast_kernel = matches!(
+            self.dtype,
             GgmlDType::Q4_0
                 | GgmlDType::Q4_1
                 | GgmlDType::Q5_0
                 | GgmlDType::Q5_1
                 | GgmlDType::Q8_0
                 | GgmlDType::Q2K
                 | GgmlDType::Q3K
                 | GgmlDType::Q4K
                 | GgmlDType::Q5K
-                | GgmlDType::Q6K => true,
-            _ => false,
-        };
+                | GgmlDType::Q6K
+                | GgmlDType::Q8K
+        );
         if fast_kernel {
             return dequantize(&self.data, self.dtype, elem_count, self.device());
         }
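This is the "Clippy lints" half of the commit: a match whose arms only produce true/false trips the match_like_matches_macro lint, and rewriting it with matches! also makes adding Q8K a one-line change. Both forms below are equivalent; sketch with a stand-in enum:

    // Stand-in enum for illustration; candle's real type is GgmlDType.
    enum Dtype {
        Q4K,
        Q8K,
        F32,
    }

    // The match-based form clippy warns about...
    fn is_fast_match(d: Dtype) -> bool {
        match d {
            Dtype::Q4K | Dtype::Q8K => true,
            _ => false,
        }
    }

    // ...and the matches!-based form it suggests instead.
    fn is_fast_matches(d: Dtype) -> bool {
        matches!(d, Dtype::Q4K | Dtype::Q8K)
    }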
@@ -229,11 +231,7 @@ impl QCudaStorage {
         storage: &CudaStorage,
         layout: &crate::Layout,
     ) -> Result<(CudaStorage, crate::Shape)> {
-        let dmmv = match layout.shape().dims() {
-            [1, 1, _] | [1, _] => true,
-            _ => false,
-        };
-        if dmmv {
+        if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
             self.dequantize_matmul_vec(self_shape, storage, layout)
         } else {
             self.dequantize_matmul(self_shape, storage, layout)

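Same lint fix as above, this time with slice patterns: the shape's dims select the dedicated matrix-vector (dmmv) kernel whenever the input is a single row, optionally with a leading batch dimension of 1. A minimal sketch of that dispatch test:

    // Slice patterns on the dims pick the matrix-vector path.
    fn is_mat_vec(dims: &[usize]) -> bool {
        matches!(dims, [1, 1, _] | [1, _])
    }

    #[test]
    fn mat_vec_dispatch() {
        assert!(is_mat_vec(&[1, 4096]));
        assert!(is_mat_vec(&[1, 1, 4096]));
        assert!(!is_mat_vec(&[5, 4096]));
    }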
@@ -224,6 +224,14 @@ typedef struct {
 } block_q6_K;
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");

+// In llama.cpp this is only used for intermediate quantization and dot products
+typedef struct {
+    float d;                // delta
+    int8_t qs[QK_K];        // quants
+    int16_t bsums[QK_K/16]; // sum of quants in groups of 16
+} block_q8_K;
+static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
+
 
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
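With QK_K = 256 the new struct is 4 + 256 + 16*2 = 292 bytes, with no padding since every field offset is naturally aligned. A Rust mirror of the layout, as a sketch to make the static_assert concrete (field names follow the CUDA struct):

    const QK_K: usize = 256;

    #[repr(C)]
    struct BlockQ8K {
        d: f32,                  // delta (per-super-block scale)
        qs: [i8; QK_K],          // quants
        bsums: [i16; QK_K / 16], // sums of quants in groups of 16
    }

    // Compile-time analogue of the CUDA static_assert.
    const _: () = assert!(std::mem::size_of::<BlockQ8K>() == 4 + QK_K + (QK_K / 16) * 2);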
@ -875,6 +883,33 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||||
|
const block_q8_K * x = (const block_q8_K *) vx;
|
||||||
|
|
||||||
|
const int i = blockIdx.x;
|
||||||
|
|
||||||
|
#if QK_K == 256
|
||||||
|
// assume 32 threads
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int il = tid/8;
|
||||||
|
const int ir = tid%8;
|
||||||
|
const int n = 8;
|
||||||
|
|
||||||
|
float * y = yy + i*QK_K + 64*il + n*ir;
|
||||||
|
|
||||||
|
const int8_t * q = x[i].qs + 64*il + n*ir;
|
||||||
|
|
||||||
|
for (int l = 0; l < n; ++l) {
|
||||||
|
y[l] = q[l] * x[i].d;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const uint8_t * q = x[i].qs;
|
||||||
|
float * y = yy + i*QK_K;
|
||||||
|
y[tid] = x[i].d * x[i].scales[0];
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
|
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||||
|
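These kernels are built with QK_K = 256, so only the first branch is used in practice. There, each CUDA block dequantizes one super-block with 32 threads: thread tid writes the 8 consecutive outputs starting at 64*(tid/8) + 8*(tid%8), which tiles the 256 values exactly. A CPU sketch of that index math in Rust (illustrative only; bsums is omitted since dequantization never reads it):

    const QK_K: usize = 256;

    struct BlockQ8K {
        d: f32,           // delta (per-super-block scale)
        qs: [i8; QK_K],   // quants
    }

    fn dequantize_q8k_block(x: &BlockQ8K, y: &mut [f32; QK_K]) {
        // Replays the 32 CUDA threads of one block sequentially.
        for tid in 0..32 {
            let il = tid / 8;
            let ir = tid % 8;
            let base = 64 * il + 8 * ir;
            for l in 0..8 {
                // Mirrors `y[l] = q[l] * x[i].d;` in the kernel.
                y[base + l] = x.qs[base + l] as f32 * x.d;
            }
        }
    }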