Handle Q5_0 and Q5_1 quants in cuda.

This commit is contained in:
laurent
2024-02-29 10:54:01 +01:00
parent 4fd00b8900
commit 2c95b7394a
3 changed files with 47 additions and 31 deletions

View File

@ -575,7 +575,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
static __device__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= k) {
@ -595,12 +595,6 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
y[iybs + iqs + y_offset] = v.y;
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
const int i = blockIdx.x;
@ -910,6 +904,14 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f
#endif
}
extern "C" __global__ void dequantize_block_q5_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
return dequantize_block<QK5_0, QR5_0, dequantize_q5_0>(vx, yy, nb32);
}
extern "C" __global__ void dequantize_block_q5_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
return dequantize_block<QK5_1, QR5_1, dequantize_q5_1>(vx, yy, nb32);
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {