mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add a cuda kernel for dequantizing q8_0. (#1804)
This commit is contained in:
@ -738,10 +738,6 @@ macro_rules! quantized_matmul {
|
|||||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||||
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
|
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
|
||||||
fn $fn_name(device: &Device) -> Result<()> {
|
fn $fn_name(device: &Device) -> Result<()> {
|
||||||
if device.is_cuda() {
|
|
||||||
// TODO Enable Cuda GGML sometime maybe.
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
test_matmul(device, (1, 3, 4, 256), $dtype)?;
|
test_matmul(device, (1, 3, 4, 256), $dtype)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -877,6 +877,30 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dequantizes q8_0 data into floats: out = d * q for every quantized value.
//
// Launch layout: one thread block per group of 8 consecutive q8_0 blocks
// (8 * 32 = 256 output floats), launched with 32 threads per thread block.
// Thread `tid` handles q8_0 block `tid % 8` of its group and the 8-value
// slice number `tid / 8` inside that block.
//
//   vx   - input buffer, reinterpreted as an array of block_q8_0
//   yy   - output buffer; 32 floats are written per q8_0 block
//   nb32 - number of q8_0 blocks available in vx; threads whose block index
//          falls at or past it return without touching memory
//
// NOTE(review): assumes block_q8_0 carries a half-precision scale `d` and at
// least 32 int8 quants in `qs` -- confirm against the struct declaration.
extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
    // 64-bit group index: 256*i exceeds INT_MAX once the output holds 2^31 or
    // more floats, which is reachable while nb32 still fits in an int.
    const int64_t i = blockIdx.x;

    // assume 32 threads
    const int tid = threadIdx.x;
    const int il = tid / 8;  // which 8-value slice of the q8_0 block (0..3)
    const int ir = tid % 8;  // which q8_0 block inside this group (0..7)

    // Global q8_0 block index; the last group may be only partially filled.
    const int64_t ib = 8 * i + ir;
    if (ib >= nb32) {
        return;
    }

    // Destination of this thread's 8 floats, computed in 64-bit arithmetic.
    float * y = yy + 256 * i + 32 * ir + 8 * il;

    const block_q8_0 * x = (const block_q8_0 *) vx + ib;
    const float d = __half2float(x->d);

    const int8_t * q = x->qs + 8 * il;

    for (int l = 0; l < 8; ++l) {
        y[l] = d * q[l];
    }
}
|
||||||
|
|
||||||
extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||||
const block_q8_K * x = (const block_q8_K *) vx;
|
const block_q8_K * x = (const block_q8_K *) vx;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user