Add the cuda dequantize f16 kernels. (#2137)

* Add the cuda dequantize f16 kernels.

* Expose the cuda kernels.

* Add some testing + fix.

* Test the other cases too.

* A few more tests.

* Add an environment variable to enable the dequantize f16 + matmul behavior.
This commit is contained in:
Laurent Mazare
2024-04-28 20:05:05 +02:00
committed by GitHub
parent c68ed8963f
commit eb26e2467e
5 changed files with 317 additions and 55 deletions

View File

@ -765,20 +765,21 @@ static __device__ void dequantize_block(const void * __restrict__ vx, dst_t * __
y[iybs + iqs + y_offset] = v.y;
}
extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
template<typename dst_t>
static __device__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
const int i = blockIdx.x;
const int64_t i = blockIdx.x;
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
const int ib = 8*i + ir;
const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
float * y = yy + 256*i + 32*ir + 4*il;
dst_t * y = yy + 256*i + 32*ir + 4*il;
const block_q4_0 * x = (const block_q4_0 *)vx + ib;
const float d = __half2float(x->d);
@ -792,20 +793,21 @@ extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, f
}
}
extern "C" __global__ void dequantize_block_q4_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
template<typename dst_t>
static __device__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
const int i = blockIdx.x;
const int64_t i = blockIdx.x;
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
const int ib = 8*i + ir;
const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
float * y = yy + 256*i + 32*ir + 4*il;
dst_t * y = yy + 256*i + 32*ir + 4*il;
const block_q4_1 * x = (const block_q4_1 *)vx + ib;
const float2 d = __half22float2(x->dm);
@ -820,7 +822,8 @@ extern "C" __global__ void dequantize_block_q4_1(const void * __restrict__ vx, f
//================================== k-quants
extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
template<typename dst_t>
static __device__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_q2_K * x = (const block_q2_K *) vx;
@ -832,7 +835,7 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f
const int is = 8*n + l/16;
const uint8_t q = x[i].qs[32*n + l];
float * y = yy + i*QK_K + 128*n;
dst_t * y = yy + i*QK_K + 128*n;
float dall = __low2half(x[i].dm);
float dmin = __high2half(x[i].dm);
@ -844,7 +847,7 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f
const int is = tid/16; // 0 or 1
const int il = tid%16; // 0...15
const uint8_t q = x[i].qs[il] >> (2*is);
float * y = yy + i*QK_K + 16*is + il;
dst_t * y = yy + i*QK_K + 16*is + il;
float dall = __low2half(x[i].dm);
float dmin = __high2half(x[i].dm);
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
@ -853,7 +856,8 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f
}
extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
template<typename dst_t>
static __device__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_q3_K * x = (const block_q3_K *) vx;
@ -877,7 +881,7 @@ extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, f
float d_all = x[i].d;
float dl = d_all * (us - 32);
float * y = yy + i*QK_K + 128*n + 32*j;
dst_t * y = yy + i*QK_K + 128*n + 32*j;
const uint8_t * q = x[i].qs + 32*n;
const uint8_t * hm = x[i].hmask;
@ -889,7 +893,7 @@ extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, f
const int im = il/8; // 0...1
const int in = il%8; // 0...7
float * y = yy + i*QK_K + 16*is + il;
dst_t * y = yy + i*QK_K + 16*is + il;
const uint8_t q = x[i].qs[il] >> (2*is);
const uint8_t h = x[i].hmask[in] >> (2*is + im);
@ -917,7 +921,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
}
#endif
extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
template<typename dst_t>
static __device__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q4_K * x = (const block_q4_K *) vx;
const int i = blockIdx.x;
@ -930,7 +935,7 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f
const int is = 2*il;
const int n = 4;
float * y = yy + i*QK_K + 64*il + n*ir;
dst_t * y = yy + i*QK_K + 64*il + n*ir;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);
@ -949,7 +954,7 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f
#else
const int tid = threadIdx.x;
const uint8_t * q = x[i].qs;
float * y = yy + i*QK_K;
dst_t * y = yy + i*QK_K;
const float d = (float)x[i].dm[0];
const float m = (float)x[i].dm[1];
y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
@ -957,7 +962,8 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f
#endif
}
extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
template<typename dst_t>
static __device__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q5_K * x = (const block_q5_K *) vx;
const int i = blockIdx.x;
@ -969,7 +975,7 @@ extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, f
const int ir = tid%16; // ir is in 0...15
const int is = 2*il; // is is in 0...6
float * y = yy + i*QK_K + 64*il + 2*ir;
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);
@ -997,25 +1003,26 @@ extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, f
const int is = tid/16; // 0 or 1
const uint8_t h = x[i].qh[in] >> im;
const float d = x[i].d;
float * y = yy + i*QK_K + tid;
dst_t * y = yy + i*QK_K + tid;
y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
#endif
}
extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
template<typename dst_t>
static __device__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q6_K * x = (const block_q6_K *) vx;
const int i = blockIdx.x;
const int64_t i = blockIdx.x;
#if QK_K == 256
// assume 64 threads - this is very slightly better than the one below
const int tid = threadIdx.x;
const int ip = tid/32; // ip is 0 or 1
const int il = tid - 32*ip; // 0...32
const int is = 8*ip + il/16;
const int64_t tid = threadIdx.x;
const int64_t ip = tid/32; // ip is 0 or 1
const int64_t il = tid - 32*ip; // 0...32
const int64_t is = 8*ip + il/16;
float * y = yy + i*QK_K + 128*ip + il;
dst_t * y = yy + i*QK_K + 128*ip + il;
const float d = x[i].d;
@ -1030,11 +1037,11 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
#else
// assume 32 threads
const int tid = threadIdx.x;
const int ip = tid/16; // 0 or 1
const int il = tid - 16*ip; // 0...15
const int64_t tid = threadIdx.x;
const int64_t ip = tid/16; // 0 or 1
const int64_t il = tid - 16*ip; // 0...15
float * y = yy + i*QK_K + 16*ip + il;
dst_t * y = yy + i*QK_K + 16*ip + il;
const float d = x[i].d;
@ -1047,7 +1054,8 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
#endif
}
extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
template<typename dst_t>
static __device__ void dequantize_block_q8_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
const int i = blockIdx.x;
// assume 32 threads
@ -1059,7 +1067,7 @@ extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, f
return;
}
float * y = yy + 256*i + 32*ir + 8*il;
dst_t * y = yy + 256*i + 32*ir + 8*il;
const block_q8_0 * x = (const block_q8_0 *)vx + ib;
const float d = __half2float(x->d);
@ -1071,7 +1079,8 @@ extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, f
}
}
extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) {
template<typename dst_t>
static __device__ void dequantize_block_q8_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q8_K * x = (const block_q8_K *) vx;
const int i = blockIdx.x;
@ -1083,7 +1092,7 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f
const int ir = tid%8;
const int n = 8;
float * y = yy + i*QK_K + 64*il + n*ir;
dst_t * y = yy + i*QK_K + 64*il + n*ir;
const int8_t * q = x[i].qs + 64*il + n*ir;
@ -1098,14 +1107,43 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f
#endif
}
extern "C" __global__ void dequantize_block_q5_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
// Q5_0 dequantization wrapper: delegates to the generic dequantize_block
// template instantiated with QK5_0, QR5_0 and dequantize_q5_0.
// dst_t is the output element type; vx, yy and nb32 are forwarded unchanged.
template<typename dst_t>
static __device__ void dequantize_block_q5_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    dequantize_block<QK5_0, QR5_0, dequantize_q5_0>(vx, yy, nb32);
}
extern "C" __global__ void dequantize_block_q5_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
// Q5_1 dequantization wrapper: delegates to the generic dequantize_block
// template instantiated with QK5_1, QR5_1 and dequantize_q5_1.
// dst_t is the output element type; vx, yy and nb32 are forwarded unchanged.
template<typename dst_t>
static __device__ void dequantize_block_q5_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    dequantize_block<QK5_1, QR5_1, dequantize_q5_1>(vx, yy, nb32);
}
// DEQUANTIZE_K(QNAME): emits two extern "C" kernel entry points for the
// templated device function dequantize_block_<QNAME>(vx, y):
//   dequantize_block_<QNAME>_f32  -- output pointer is float *
//   dequantize_block_<QNAME>_f16  -- output pointer is half *
// Both simply forward their arguments to the template, which deduces dst_t
// from the output pointer type.
// NOTE: the final line of the macro deliberately has no trailing backslash,
// so the definition ends here no matter what line follows it.
#define DEQUANTIZE_K(QNAME) \
  extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y) { \
    dequantize_block_##QNAME(vx, y); \
  } \
  extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y) { \
    dequantize_block_##QNAME(vx, y); \
  }
// DEQUANTIZE(QNAME): emits two extern "C" kernel entry points for the
// templated device function dequantize_block_<QNAME>(vx, y, k):
//   dequantize_block_<QNAME>_f32  -- output pointer is float *
//   dequantize_block_<QNAME>_f16  -- output pointer is half *
// The extra int argument k is passed through as the template's third
// argument (named nb32 in the templated device functions).
// NOTE: the final line of the macro deliberately has no trailing backslash,
// so the definition ends here no matter what line follows it.
#define DEQUANTIZE(QNAME) \
  extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y, const int k) { \
    dequantize_block_##QNAME(vx, y, k); \
  } \
  extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) { \
    dequantize_block_##QNAME(vx, y, k); \
  }
// Instantiate the exported _f32 and _f16 kernel entry points for each format.
// DEQUANTIZE_K entries take (vx, y); DEQUANTIZE entries take (vx, y, k).
DEQUANTIZE_K(q2_K)
DEQUANTIZE_K(q3_K)
DEQUANTIZE_K(q4_K)
DEQUANTIZE_K(q5_K)
DEQUANTIZE_K(q6_K)
DEQUANTIZE_K(q8_K)
DEQUANTIZE(q4_0)
DEQUANTIZE(q4_1)
DEQUANTIZE(q5_0)
DEQUANTIZE(q5_1)
DEQUANTIZE(q8_0)
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {