Cuda kernel for dequantizing q8k. (#1760)

* Cuda kernel for dequantizing q8k.

* Clippy lints.
Laurent Mazare
2024-02-26 08:42:44 +01:00
committed by GitHub
parent 918136ba46
commit badf886583
3 changed files with 55 additions and 22 deletions


@@ -224,6 +224,14 @@ typedef struct {
} block_q6_K;
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
// In llama.cpp this is only used for intermediate quantization and dot products
typedef struct {
    float   d;                // delta
    int8_t  qs[QK_K];         // quants
    int16_t bsums[QK_K/16];   // sum of quants in groups of 16
} block_q8_K;
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
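For QK_K == 256 the static_assert works out to 4 + 256 + 16*2 = 292 bytes per block. As a rough illustration of how the fields relate, here is a minimal host-side reference quantizer for a single block; the function name is hypothetical and not part of this commit, and it assumes QK_K == 256 and the block_q8_K definition above:

#include <math.h>

// Hypothetical reference quantizer, not part of this diff: fills one
// block_q8_K from QK_K floats, assuming the struct layout above.
static void quantize_block_q8_K_ref(const float * x, block_q8_K * y) {
    float amax = 0.0f;                              // absolute max over the block
    for (int j = 0; j < QK_K; ++j) {
        const float ax = fabsf(x[j]);
        if (ax > amax) amax = ax;
    }
    const float d  = amax / 127.0f;                 // one delta for the whole block
    const float id = d != 0.0f ? 1.0f / d : 0.0f;
    for (int j = 0; j < QK_K; ++j) {
        y->qs[j] = (int8_t) roundf(x[j] * id);      // quants in [-127, 127]
    }
    // bsums: sums of the quants in groups of 16, used by the dot-product kernels
    for (int g = 0; g < QK_K/16; ++g) {
        int16_t s = 0;
        for (int j = 0; j < 16; ++j) s += y->qs[16*g + j];
        y->bsums[g] = s;
    }
    y->d = d;
}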
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
@@ -875,6 +883,33 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
#endif
}
extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) {
    const block_q8_K * x = (const block_q8_K *) vx;
    const int i = blockIdx.x; // one CUDA block per quantized block

#if QK_K == 256
    // assume 32 threads: each thread dequantizes n = 8 contiguous quants;
    // il picks one of four 64-wide chunks, ir picks the 8-wide slice within it.
    const int tid = threadIdx.x;
    const int il = tid/8;
    const int ir = tid%8;
    const int n = 8;

    float * y = yy + i*QK_K + 64*il + n*ir;
    const int8_t * q = x[i].qs + 64*il + n*ir;

    for (int l = 0; l < n; ++l) {
        y[l] = q[l] * x[i].d;
    }
#else
    // one thread per quant: scale the int8 value by the block delta
    const int tid = threadIdx.x;
    const int8_t * q = x[i].qs;
    float * y = yy + i*QK_K;
    y[tid] = q[tid] * x[i].d;
#endif
}
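A host-side launch sketch follows, matching the kernel's assumption of 32 threads in the QK_K == 256 path; the wrapper name and the nblocks parameter are illustrative, not from this diff:

// Hypothetical launcher, assuming QK_K == 256 and that vx points to
// nblocks contiguous block_q8_K structs in device memory.
static void dequantize_q8_K_cuda(const void * vx, float * y, const int nblocks, cudaStream_t stream) {
    // one CUDA block per quantized block; 32 threads * 8 quants each = 256 = QK_K
    dequantize_block_q8_K<<<nblocks, 32, 0, stream>>>(vx, y);
}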
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {