Cuda kernel for dequantizing q8k. (#1760)

* Cuda kernel for dequantizing q8k.

* Clippy lints.
This commit is contained in:
Laurent Mazare
2024-02-26 08:42:44 +01:00
committed by GitHub
parent 918136ba46
commit badf886583
3 changed files with 55 additions and 22 deletions

View File

@ -11,15 +11,15 @@ use candle_core::quantized::{QMatMul, QTensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let q = Tensor::randn(0f32, 1.0, (72, 32), &device)?;
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
let q_cpu = q.to_device(&Device::Cpu)?;
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q4_0)?;
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
let q = QMatMul::from_qtensor(q)?;
let x = Tensor::randn(0f32, 1.0, (5, 32), &device)?;
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
let res_q_cuda = q.forward(&x)?;
println!("{res_q_cuda}");
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q4_0)?;
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
let x_cpu = x.to_device(&Device::Cpu)?;

View File

@ -36,7 +36,8 @@ fn dequantize(
GgmlDType::Q4K => ("dequantize_block_q4_K", true),
GgmlDType::Q5K => ("dequantize_block_q5_K", true),
GgmlDType::Q6K => ("dequantize_block_q6_K", true),
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
GgmlDType::Q8K => ("dequantize_block_q8_K", true),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
@ -115,19 +116,20 @@ impl QCudaStorage {
}
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
let fast_kernel = match self.dtype {
let fast_kernel = matches!(
self.dtype,
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q2K
| GgmlDType::Q3K
| GgmlDType::Q4K
| GgmlDType::Q5K
| GgmlDType::Q6K => true,
_ => false,
};
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q2K
| GgmlDType::Q3K
| GgmlDType::Q4K
| GgmlDType::Q5K
| GgmlDType::Q6K
| GgmlDType::Q8K
);
if fast_kernel {
return dequantize(&self.data, self.dtype, elem_count, self.device());
}
@ -229,11 +231,7 @@ impl QCudaStorage {
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
let dmmv = match layout.shape().dims() {
[1, 1, _] | [1, _] => true,
_ => false,
};
if dmmv {
if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
self.dequantize_matmul_vec(self_shape, storage, layout)
} else {
self.dequantize_matmul(self_shape, storage, layout)