mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Cuda kernel for dequantizing q8k. (#1760)
* Cuda kernel for dequantizing q8k. * Clippy lints.
This commit is contained in:
@ -224,6 +224,14 @@ typedef struct {
|
||||
} block_q6_K;
|
||||
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
|
||||
|
||||
// 8-bit K-quant block. In llama.cpp this type is only used for intermediate
// quantization and for dot products.
typedef struct {
    float   d;                // delta
    int8_t  qs[QK_K];         // quants
    int16_t bsums[QK_K/16];   // sum of quants in groups of 16
} block_q8_K;
// The struct must be exactly one float, QK_K one-byte quants and QK_K/16
// 16-bit group sums, with no padding in between.
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K*sizeof(int8_t) + (QK_K/16)*sizeof(int16_t), "wrong q8_K block size/padding");
|
||||
|
||||
|
||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
||||
@ -875,6 +883,33 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
|
||||
#endif
|
||||
}
|
||||
|
||||
// Dequantizes q8_K blocks to float: yy[i*QK_K + j] = x[i].qs[j] * x[i].d.
//
// Launch layout: one thread block per quantized block; blockIdx.x selects the
// block_q8_K entry that is converted.
//   - QK_K == 256: 32 threads are expected, each converting 8 consecutive
//     quants. Threads beyond the first 32 return without touching memory.
//   - otherwise:   any blockDim.x works; the threads stride over the QK_K quants.
//
// vx: array of block_q8_K, indexed by blockIdx.x.
// yy: output buffer, written at [blockIdx.x*QK_K, blockIdx.x*QK_K + QK_K).
extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) {
    const block_q8_K * x = (const block_q8_K *) vx;

    // 64-bit index so that i*QK_K cannot overflow a 32-bit int on very large tensors.
    const int64_t i = blockIdx.x;
    const int tid = threadIdx.x;

    // The scale is shared by every quant of the block; read it once.
    const float d = x[i].d;

#if QK_K == 256
    const int n = 8;

    // Only QK_K/n (= 32) threads have work; any extra thread would index past
    // qs[] and into the next block's output.
    if (tid >= QK_K/n) {
        return;
    }

    const int il = tid/8;
    const int ir = tid%8;

    float * y = yy + i*QK_K + 64*il + n*ir;
    const int8_t * q = x[i].qs + 64*il + n*ir;

    for (int l = 0; l < n; ++l) {
        y[l] = q[l] * d;
    }
#else
    // block_q8_K has no per-group scales: every quant is scaled by d alone.
    const int8_t * q = x[i].qs;
    float * y = yy + i*QK_K;

    for (int j = tid; j < QK_K; j += blockDim.x) {
        y[j] = q[j] * d;
    }
#endif
}
|
||||
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||
|
Reference in New Issue
Block a user