mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add the cuda dequantize f16 kernels. (#2137)
* Add the cuda dequantize f16 kernels. * Expose the cuda kernels. * Add some testing + fix. * Test the other cases too. * A few more tests. * Add an environment variable to enable the dequantize f16 + matmul behavior.
This commit is contained in:
@ -765,20 +765,21 @@ static __device__ void dequantize_block(const void * __restrict__ vx, dst_t * __
|
||||
y[iybs + iqs + y_offset] = v.y;
|
||||
}
|
||||
|
||||
extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const int64_t i = blockIdx.x;
|
||||
|
||||
// assume 32 threads
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8;
|
||||
const int ir = tid%8;
|
||||
const int ib = 8*i + ir;
|
||||
const int64_t ib = 8*i + ir;
|
||||
if (ib >= nb32) {
|
||||
return;
|
||||
}
|
||||
|
||||
float * y = yy + 256*i + 32*ir + 4*il;
|
||||
dst_t * y = yy + 256*i + 32*ir + 4*il;
|
||||
|
||||
const block_q4_0 * x = (const block_q4_0 *)vx + ib;
|
||||
const float d = __half2float(x->d);
|
||||
@ -792,20 +793,21 @@ extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, f
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" __global__ void dequantize_block_q4_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const int64_t i = blockIdx.x;
|
||||
|
||||
// assume 32 threads
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8;
|
||||
const int ir = tid%8;
|
||||
const int ib = 8*i + ir;
|
||||
const int64_t ib = 8*i + ir;
|
||||
if (ib >= nb32) {
|
||||
return;
|
||||
}
|
||||
|
||||
float * y = yy + 256*i + 32*ir + 4*il;
|
||||
dst_t * y = yy + 256*i + 32*ir + 4*il;
|
||||
|
||||
const block_q4_1 * x = (const block_q4_1 *)vx + ib;
|
||||
const float2 d = __half22float2(x->dm);
|
||||
@ -820,7 +822,8 @@ extern "C" __global__ void dequantize_block_q4_1(const void * __restrict__ vx, f
|
||||
|
||||
//================================== k-quants
|
||||
|
||||
extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_q2_K * x = (const block_q2_K *) vx;
|
||||
@ -832,7 +835,7 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f
|
||||
const int is = 8*n + l/16;
|
||||
|
||||
const uint8_t q = x[i].qs[32*n + l];
|
||||
float * y = yy + i*QK_K + 128*n;
|
||||
dst_t * y = yy + i*QK_K + 128*n;
|
||||
|
||||
float dall = __low2half(x[i].dm);
|
||||
float dmin = __high2half(x[i].dm);
|
||||
@ -844,7 +847,7 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f
|
||||
const int is = tid/16; // 0 or 1
|
||||
const int il = tid%16; // 0...15
|
||||
const uint8_t q = x[i].qs[il] >> (2*is);
|
||||
float * y = yy + i*QK_K + 16*is + il;
|
||||
dst_t * y = yy + i*QK_K + 16*is + il;
|
||||
float dall = __low2half(x[i].dm);
|
||||
float dmin = __high2half(x[i].dm);
|
||||
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
|
||||
@ -853,7 +856,8 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f
|
||||
|
||||
}
|
||||
|
||||
extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_q3_K * x = (const block_q3_K *) vx;
|
||||
@ -877,7 +881,7 @@ extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, f
|
||||
float d_all = x[i].d;
|
||||
float dl = d_all * (us - 32);
|
||||
|
||||
float * y = yy + i*QK_K + 128*n + 32*j;
|
||||
dst_t * y = yy + i*QK_K + 128*n + 32*j;
|
||||
const uint8_t * q = x[i].qs + 32*n;
|
||||
const uint8_t * hm = x[i].hmask;
|
||||
|
||||
@ -889,7 +893,7 @@ extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, f
|
||||
const int im = il/8; // 0...1
|
||||
const int in = il%8; // 0...7
|
||||
|
||||
float * y = yy + i*QK_K + 16*is + il;
|
||||
dst_t * y = yy + i*QK_K + 16*is + il;
|
||||
|
||||
const uint8_t q = x[i].qs[il] >> (2*is);
|
||||
const uint8_t h = x[i].hmask[in] >> (2*is + im);
|
||||
@ -917,7 +921,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
|
||||
}
|
||||
#endif
|
||||
|
||||
extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
const block_q4_K * x = (const block_q4_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
@ -930,7 +935,7 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f
|
||||
const int is = 2*il;
|
||||
const int n = 4;
|
||||
|
||||
float * y = yy + i*QK_K + 64*il + n*ir;
|
||||
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||
|
||||
const float dall = __low2half(x[i].dm);
|
||||
const float dmin = __high2half(x[i].dm);
|
||||
@ -949,7 +954,7 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f
|
||||
#else
|
||||
const int tid = threadIdx.x;
|
||||
const uint8_t * q = x[i].qs;
|
||||
float * y = yy + i*QK_K;
|
||||
dst_t * y = yy + i*QK_K;
|
||||
const float d = (float)x[i].dm[0];
|
||||
const float m = (float)x[i].dm[1];
|
||||
y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
|
||||
@ -957,7 +962,8 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f
|
||||
#endif
|
||||
}
|
||||
|
||||
extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
const block_q5_K * x = (const block_q5_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
@ -969,7 +975,7 @@ extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, f
|
||||
const int ir = tid%16; // ir is in 0...15
|
||||
const int is = 2*il; // is is in 0...6
|
||||
|
||||
float * y = yy + i*QK_K + 64*il + 2*ir;
|
||||
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
|
||||
|
||||
const float dall = __low2half(x[i].dm);
|
||||
const float dmin = __high2half(x[i].dm);
|
||||
@ -997,25 +1003,26 @@ extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, f
|
||||
const int is = tid/16; // 0 or 1
|
||||
const uint8_t h = x[i].qh[in] >> im;
|
||||
const float d = x[i].d;
|
||||
float * y = yy + i*QK_K + tid;
|
||||
dst_t * y = yy + i*QK_K + tid;
|
||||
y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
|
||||
y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
const block_q6_K * x = (const block_q6_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const int64_t i = blockIdx.x;
|
||||
#if QK_K == 256
|
||||
|
||||
// assume 64 threads - this is very slightly better than the one below
|
||||
const int tid = threadIdx.x;
|
||||
const int ip = tid/32; // ip is 0 or 1
|
||||
const int il = tid - 32*ip; // 0...32
|
||||
const int is = 8*ip + il/16;
|
||||
const int64_t tid = threadIdx.x;
|
||||
const int64_t ip = tid/32; // ip is 0 or 1
|
||||
const int64_t il = tid - 32*ip; // 0...32
|
||||
const int64_t is = 8*ip + il/16;
|
||||
|
||||
float * y = yy + i*QK_K + 128*ip + il;
|
||||
dst_t * y = yy + i*QK_K + 128*ip + il;
|
||||
|
||||
const float d = x[i].d;
|
||||
|
||||
@ -1030,11 +1037,11 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
|
||||
#else
|
||||
|
||||
// assume 32 threads
|
||||
const int tid = threadIdx.x;
|
||||
const int ip = tid/16; // 0 or 1
|
||||
const int il = tid - 16*ip; // 0...15
|
||||
const int64_t tid = threadIdx.x;
|
||||
const int64_t ip = tid/16; // 0 or 1
|
||||
const int64_t il = tid - 16*ip; // 0...15
|
||||
|
||||
float * y = yy + i*QK_K + 16*ip + il;
|
||||
dst_t * y = yy + i*QK_K + 16*ip + il;
|
||||
|
||||
const float d = x[i].d;
|
||||
|
||||
@ -1047,7 +1054,8 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
|
||||
#endif
|
||||
}
|
||||
|
||||
extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q8_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||
const int i = blockIdx.x;
|
||||
|
||||
// assume 32 threads
|
||||
@ -1059,7 +1067,7 @@ extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, f
|
||||
return;
|
||||
}
|
||||
|
||||
float * y = yy + 256*i + 32*ir + 8*il;
|
||||
dst_t * y = yy + 256*i + 32*ir + 8*il;
|
||||
|
||||
const block_q8_0 * x = (const block_q8_0 *)vx + ib;
|
||||
const float d = __half2float(x->d);
|
||||
@ -1071,7 +1079,8 @@ extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, f
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q8_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
const block_q8_K * x = (const block_q8_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
@ -1083,7 +1092,7 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f
|
||||
const int ir = tid%8;
|
||||
const int n = 8;
|
||||
|
||||
float * y = yy + i*QK_K + 64*il + n*ir;
|
||||
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||
|
||||
const int8_t * q = x[i].qs + 64*il + n*ir;
|
||||
|
||||
@ -1098,14 +1107,43 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f
|
||||
#endif
|
||||
}
|
||||
|
||||
extern "C" __global__ void dequantize_block_q5_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q5_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||
return dequantize_block<QK5_0, QR5_0, dequantize_q5_0>(vx, yy, nb32);
|
||||
}
|
||||
|
||||
extern "C" __global__ void dequantize_block_q5_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q5_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||
return dequantize_block<QK5_1, QR5_1, dequantize_q5_1>(vx, yy, nb32);
|
||||
}
|
||||
|
||||
#define DEQUANTIZE_K(QNAME) \
|
||||
extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y) { \
|
||||
dequantize_block_##QNAME(vx, y); \
|
||||
} \
|
||||
extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y) { \
|
||||
dequantize_block_##QNAME(vx, y); \
|
||||
} \
|
||||
|
||||
#define DEQUANTIZE(QNAME) \
|
||||
extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y, const int k) { \
|
||||
dequantize_block_##QNAME(vx, y, k); \
|
||||
} \
|
||||
extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) { \
|
||||
dequantize_block_##QNAME(vx, y, k); \
|
||||
} \
|
||||
|
||||
DEQUANTIZE_K(q2_K)
|
||||
DEQUANTIZE_K(q3_K)
|
||||
DEQUANTIZE_K(q4_K)
|
||||
DEQUANTIZE_K(q5_K)
|
||||
DEQUANTIZE_K(q6_K)
|
||||
DEQUANTIZE_K(q8_K)
|
||||
DEQUANTIZE(q4_0)
|
||||
DEQUANTIZE(q4_1)
|
||||
DEQUANTIZE(q5_0)
|
||||
DEQUANTIZE(q5_1)
|
||||
DEQUANTIZE(q8_0)
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||
|
Reference in New Issue
Block a user