Add the cuda dequantize f16 kernels. (#2137)

* Add the cuda dequantize f16 kernels.

* Expose the cuda kernels.

* Add some testing + fix.

* Test the other cases too.

* A few more tests.

* Add an environment variable to enable the dequantize f16 + matmul behavior.
This commit is contained in:
Laurent Mazare
2024-04-28 20:05:05 +02:00
committed by GitHub
parent c68ed8963f
commit eb26e2467e
5 changed files with 317 additions and 55 deletions

View File

@ -2,6 +2,7 @@ use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType; use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr}; use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result}; use crate::{CudaDevice, CudaStorage, Result};
use half::f16;
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice}; use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
@ -59,7 +60,7 @@ fn quantize_q8_1(
Ok(()) Ok(())
} }
fn dequantize( fn dequantize_f32(
data: &CudaSlice<u8>, data: &CudaSlice<u8>,
dtype: GgmlDType, dtype: GgmlDType,
elem_count: usize, elem_count: usize,
@ -69,27 +70,27 @@ fn dequantize(
let nb = (elem_count + 255) / 256; let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype { let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb), GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb), GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
GgmlDType::Q5_0 => ( GgmlDType::Q5_0 => (
"dequantize_block_q5_0", "dequantize_block_q5_0_f32",
false, false,
CUDA_DEQUANTIZE_BLOCK_SIZE, CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE), ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
), ),
GgmlDType::Q5_1 => ( GgmlDType::Q5_1 => (
"dequantize_block_q5_1", "dequantize_block_q5_1_f32",
false, false,
CUDA_DEQUANTIZE_BLOCK_SIZE, CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE), ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
), ),
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb), GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb), GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb), GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb), GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb), GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb), GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb), GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"), _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
}; };
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
@ -116,6 +117,63 @@ fn dequantize(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
} }
fn dequantize_f16(
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
GgmlDType::Q5_0 => (
"dequantize_block_q5_0_f16",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1_f16",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
if is_k {
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mul_mat_vec( fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>, data: &CudaSlice<u8>,
y: &CudaView<f32>, y: &CudaView<f32>,
@ -341,7 +399,7 @@ impl QCudaStorage {
| GgmlDType::Q8K | GgmlDType::Q8K
); );
if fast_kernel { if fast_kernel {
return dequantize(&self.data, self.dtype, elem_count, self.device()); return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
} }
// Run the dequantization on cpu. // Run the dequantization on cpu.
@ -369,6 +427,10 @@ impl QCudaStorage {
.storage_from_cpu_storage(&crate::CpuStorage::F32(out)) .storage_from_cpu_storage(&crate::CpuStorage::F32(out))
} }
pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
dequantize_f16(&self.data, self.dtype, elem_count, self.device())
}
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> { pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu. // Run the quantization on cpu.
let src = match &src.slice { let src = match &src.slice {

View File

@ -24,6 +24,10 @@ impl QCudaStorage {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> { pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }

View File

@ -1,4 +1,4 @@
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor}; use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*; use k_quants::*;
use std::borrow::Cow; use std::borrow::Cow;
@ -360,10 +360,25 @@ impl QTensor {
pub fn dequantize(&self, device: &Device) -> Result<Tensor> { pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let storage = self.storage.dequantize(self.shape.elem_count())?; let storage = self.storage.dequantize(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none(); let none = crate::op::BackpropOp::none();
let is_variable = false; crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable) }
pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
// In the CUDA case, we have a specialized kernel as this can be useful for volta
// architectures. https://github.com/huggingface/candle/issues/2136
match &self.storage {
QStorage::Cuda(s) => {
let s = s.dequantize_f16(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
.to_device(device) .to_device(device)
} }
_ => {
let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
Ok(s)
}
}
}
pub fn storage_size_in_bytes(&self) -> usize { pub fn storage_size_in_bytes(&self) -> usize {
self.storage.size_in_bytes() self.storage.size_in_bytes()
@ -378,6 +393,7 @@ impl QTensor {
pub enum QMatMul { pub enum QMatMul {
QTensor(std::sync::Arc<QTensor>), QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor), Tensor(Tensor),
TensorF16(Tensor),
} }
thread_local! { thread_local! {
@ -391,6 +407,17 @@ thread_local! {
} }
} }
thread_local! {
static DEQUANTIZE_ALL_F16: bool = {
match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl QMatMul { impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> { pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let dequantize = match qtensor.dtype() { let dequantize = match qtensor.dtype() {
@ -400,6 +427,9 @@ impl QMatMul {
let t = if dequantize { let t = if dequantize {
let tensor = qtensor.dequantize(&qtensor.device())?; let tensor = qtensor.dequantize(&qtensor.device())?;
Self::Tensor(tensor) Self::Tensor(tensor)
} else if DEQUANTIZE_ALL_F16.with(|b| *b) {
let tensor = qtensor.dequantize_f16(&qtensor.device())?;
Self::TensorF16(tensor)
} else { } else {
Self::QTensor(qtensor) Self::QTensor(qtensor)
}; };
@ -486,6 +516,15 @@ impl crate::Module for QMatMul {
}; };
xs.matmul(&w) xs.matmul(&w)
} }
Self::TensorF16(w) => {
let in_dtype = xs.dtype();
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
}
} }
} }
} }

View File

@ -3,7 +3,7 @@ use candle_core::{
quantized::{self, GgmlDType}, quantized::{self, GgmlDType},
test_device, test_device,
test_utils::to_vec2_round, test_utils::to_vec2_round,
Device, IndexOp, Module, Result, Tensor, DType, Device, IndexOp, Module, Result, Tensor,
}; };
use quantized::{k_quants, GgmlType}; use quantized::{k_quants, GgmlType};
use rand::prelude::*; use rand::prelude::*;
@ -225,6 +225,13 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
let dst = quant.dequantize(device)?; let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
assert_eq!( assert_eq!(
dst.to_vec1::<f32>()?, dst.to_vec1::<f32>()?,
&[ &[
@ -251,6 +258,13 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
let dst = quant.dequantize(device)?; let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
assert_eq!( assert_eq!(
round_vector(&dst.to_vec1::<f32>()?), round_vector(&dst.to_vec1::<f32>()?),
&[ &[
@ -277,6 +291,13 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
let dst = quant.dequantize(device)?; let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
assert_eq!( assert_eq!(
round_vector(&dst.to_vec1::<f32>()?), round_vector(&dst.to_vec1::<f32>()?),
&[ &[
@ -303,6 +324,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
let dst = quant.dequantize(device)?; let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
assert_eq!( assert_eq!(
round_vector(&dst.to_vec1::<f32>()?), round_vector(&dst.to_vec1::<f32>()?),
&[ &[
@ -387,6 +415,13 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?; let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?; let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?); let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
if error > max_error { if error > max_error {
bail!( bail!(
@ -404,6 +439,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?; let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?; let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?; let dst = dst.to_vec1::<f32>()?;
@ -423,6 +465,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?; let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?; let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?; let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?; let dst_big = dst_big.to_vec1::<f32>()?;
@ -437,6 +486,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?; let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?; let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?; let dst = dst.to_vec1::<f32>()?;
@ -456,6 +512,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?; let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?; let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?; let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?; let dst_big = dst_big.to_vec1::<f32>()?;
@ -470,6 +533,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?; let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?; let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?; let dst = dst.to_vec1::<f32>()?;
@ -489,6 +559,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?; let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?; let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?; let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?; let dst_big = dst_big.to_vec1::<f32>()?;
@ -503,6 +580,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?; let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?; let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?; let dst = dst.to_vec1::<f32>()?;
@ -522,6 +606,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?; let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?; let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?; let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?; let dst_big = dst_big.to_vec1::<f32>()?;
@ -536,6 +627,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?; let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?; let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?; let dst = dst.to_vec1::<f32>()?;
@ -555,6 +653,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?; let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?; let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?; let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?; let dst_big = dst_big.to_vec1::<f32>()?;
@ -569,6 +674,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?; let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?; let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?; let dst = dst.to_vec1::<f32>()?;
@ -588,6 +700,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?; let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?; let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?; let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?; let dst_big = dst_big.to_vec1::<f32>()?;

View File

@ -765,20 +765,21 @@ static __device__ void dequantize_block(const void * __restrict__ vx, dst_t * __
y[iybs + iqs + y_offset] = v.y; y[iybs + iqs + y_offset] = v.y;
} }
extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { template<typename dst_t>
static __device__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
const int i = blockIdx.x; const int64_t i = blockIdx.x;
// assume 32 threads // assume 32 threads
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int il = tid/8; const int il = tid/8;
const int ir = tid%8; const int ir = tid%8;
const int ib = 8*i + ir; const int64_t ib = 8*i + ir;
if (ib >= nb32) { if (ib >= nb32) {
return; return;
} }
float * y = yy + 256*i + 32*ir + 4*il; dst_t * y = yy + 256*i + 32*ir + 4*il;
const block_q4_0 * x = (const block_q4_0 *)vx + ib; const block_q4_0 * x = (const block_q4_0 *)vx + ib;
const float d = __half2float(x->d); const float d = __half2float(x->d);
@ -792,20 +793,21 @@ extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, f
} }
} }
extern "C" __global__ void dequantize_block_q4_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { template<typename dst_t>
static __device__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
const int i = blockIdx.x; const int64_t i = blockIdx.x;
// assume 32 threads // assume 32 threads
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int il = tid/8; const int il = tid/8;
const int ir = tid%8; const int ir = tid%8;
const int ib = 8*i + ir; const int64_t ib = 8*i + ir;
if (ib >= nb32) { if (ib >= nb32) {
return; return;
} }
float * y = yy + 256*i + 32*ir + 4*il; dst_t * y = yy + 256*i + 32*ir + 4*il;
const block_q4_1 * x = (const block_q4_1 *)vx + ib; const block_q4_1 * x = (const block_q4_1 *)vx + ib;
const float2 d = __half22float2(x->dm); const float2 d = __half22float2(x->dm);
@ -820,7 +822,8 @@ extern "C" __global__ void dequantize_block_q4_1(const void * __restrict__ vx, f
//================================== k-quants //================================== k-quants
extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) { template<typename dst_t>
static __device__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x; const int i = blockIdx.x;
const block_q2_K * x = (const block_q2_K *) vx; const block_q2_K * x = (const block_q2_K *) vx;
@ -832,7 +835,7 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f
const int is = 8*n + l/16; const int is = 8*n + l/16;
const uint8_t q = x[i].qs[32*n + l]; const uint8_t q = x[i].qs[32*n + l];
float * y = yy + i*QK_K + 128*n; dst_t * y = yy + i*QK_K + 128*n;
float dall = __low2half(x[i].dm); float dall = __low2half(x[i].dm);
float dmin = __high2half(x[i].dm); float dmin = __high2half(x[i].dm);
@ -844,7 +847,7 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f
const int is = tid/16; // 0 or 1 const int is = tid/16; // 0 or 1
const int il = tid%16; // 0...15 const int il = tid%16; // 0...15
const uint8_t q = x[i].qs[il] >> (2*is); const uint8_t q = x[i].qs[il] >> (2*is);
float * y = yy + i*QK_K + 16*is + il; dst_t * y = yy + i*QK_K + 16*is + il;
float dall = __low2half(x[i].dm); float dall = __low2half(x[i].dm);
float dmin = __high2half(x[i].dm); float dmin = __high2half(x[i].dm);
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
@ -853,7 +856,8 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f
} }
extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) { template<typename dst_t>
static __device__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x; const int i = blockIdx.x;
const block_q3_K * x = (const block_q3_K *) vx; const block_q3_K * x = (const block_q3_K *) vx;
@ -877,7 +881,7 @@ extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, f
float d_all = x[i].d; float d_all = x[i].d;
float dl = d_all * (us - 32); float dl = d_all * (us - 32);
float * y = yy + i*QK_K + 128*n + 32*j; dst_t * y = yy + i*QK_K + 128*n + 32*j;
const uint8_t * q = x[i].qs + 32*n; const uint8_t * q = x[i].qs + 32*n;
const uint8_t * hm = x[i].hmask; const uint8_t * hm = x[i].hmask;
@ -889,7 +893,7 @@ extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, f
const int im = il/8; // 0...1 const int im = il/8; // 0...1
const int in = il%8; // 0...7 const int in = il%8; // 0...7
float * y = yy + i*QK_K + 16*is + il; dst_t * y = yy + i*QK_K + 16*is + il;
const uint8_t q = x[i].qs[il] >> (2*is); const uint8_t q = x[i].qs[il] >> (2*is);
const uint8_t h = x[i].hmask[in] >> (2*is + im); const uint8_t h = x[i].hmask[in] >> (2*is + im);
@ -917,7 +921,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
} }
#endif #endif
extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) { template<typename dst_t>
static __device__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q4_K * x = (const block_q4_K *) vx; const block_q4_K * x = (const block_q4_K *) vx;
const int i = blockIdx.x; const int i = blockIdx.x;
@ -930,7 +935,7 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f
const int is = 2*il; const int is = 2*il;
const int n = 4; const int n = 4;
float * y = yy + i*QK_K + 64*il + n*ir; dst_t * y = yy + i*QK_K + 64*il + n*ir;
const float dall = __low2half(x[i].dm); const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm); const float dmin = __high2half(x[i].dm);
@ -949,7 +954,7 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f
#else #else
const int tid = threadIdx.x; const int tid = threadIdx.x;
const uint8_t * q = x[i].qs; const uint8_t * q = x[i].qs;
float * y = yy + i*QK_K; dst_t * y = yy + i*QK_K;
const float d = (float)x[i].dm[0]; const float d = (float)x[i].dm[0];
const float m = (float)x[i].dm[1]; const float m = (float)x[i].dm[1];
y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
@ -957,7 +962,8 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f
#endif #endif
} }
extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) { template<typename dst_t>
static __device__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q5_K * x = (const block_q5_K *) vx; const block_q5_K * x = (const block_q5_K *) vx;
const int i = blockIdx.x; const int i = blockIdx.x;
@ -969,7 +975,7 @@ extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, f
const int ir = tid%16; // ir is in 0...15 const int ir = tid%16; // ir is in 0...15
const int is = 2*il; // is is in 0...6 const int is = 2*il; // is is in 0...6
float * y = yy + i*QK_K + 64*il + 2*ir; dst_t * y = yy + i*QK_K + 64*il + 2*ir;
const float dall = __low2half(x[i].dm); const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm); const float dmin = __high2half(x[i].dm);
@ -997,25 +1003,26 @@ extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, f
const int is = tid/16; // 0 or 1 const int is = tid/16; // 0 or 1
const uint8_t h = x[i].qh[in] >> im; const uint8_t h = x[i].qh[in] >> im;
const float d = x[i].d; const float d = x[i].d;
float * y = yy + i*QK_K + tid; dst_t * y = yy + i*QK_K + tid;
y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16)); y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16)); y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
#endif #endif
} }
extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) { template<typename dst_t>
static __device__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q6_K * x = (const block_q6_K *) vx; const block_q6_K * x = (const block_q6_K *) vx;
const int i = blockIdx.x; const int64_t i = blockIdx.x;
#if QK_K == 256 #if QK_K == 256
// assume 64 threads - this is very slightly better than the one below // assume 64 threads - this is very slightly better than the one below
const int tid = threadIdx.x; const int64_t tid = threadIdx.x;
const int ip = tid/32; // ip is 0 or 1 const int64_t ip = tid/32; // ip is 0 or 1
const int il = tid - 32*ip; // 0...32 const int64_t il = tid - 32*ip; // 0...32
const int is = 8*ip + il/16; const int64_t is = 8*ip + il/16;
float * y = yy + i*QK_K + 128*ip + il; dst_t * y = yy + i*QK_K + 128*ip + il;
const float d = x[i].d; const float d = x[i].d;
@ -1030,11 +1037,11 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
#else #else
// assume 32 threads // assume 32 threads
const int tid = threadIdx.x; const int64_t tid = threadIdx.x;
const int ip = tid/16; // 0 or 1 const int64_t ip = tid/16; // 0 or 1
const int il = tid - 16*ip; // 0...15 const int64_t il = tid - 16*ip; // 0...15
float * y = yy + i*QK_K + 16*ip + il; dst_t * y = yy + i*QK_K + 16*ip + il;
const float d = x[i].d; const float d = x[i].d;
@ -1047,7 +1054,8 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
#endif #endif
} }
extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { template<typename dst_t>
static __device__ void dequantize_block_q8_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
const int i = blockIdx.x; const int i = blockIdx.x;
// assume 32 threads // assume 32 threads
@ -1059,7 +1067,7 @@ extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, f
return; return;
} }
float * y = yy + 256*i + 32*ir + 8*il; dst_t * y = yy + 256*i + 32*ir + 8*il;
const block_q8_0 * x = (const block_q8_0 *)vx + ib; const block_q8_0 * x = (const block_q8_0 *)vx + ib;
const float d = __half2float(x->d); const float d = __half2float(x->d);
@ -1071,7 +1079,8 @@ extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, f
} }
} }
extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) { template<typename dst_t>
static __device__ void dequantize_block_q8_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q8_K * x = (const block_q8_K *) vx; const block_q8_K * x = (const block_q8_K *) vx;
const int i = blockIdx.x; const int i = blockIdx.x;
@ -1083,7 +1092,7 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f
const int ir = tid%8; const int ir = tid%8;
const int n = 8; const int n = 8;
float * y = yy + i*QK_K + 64*il + n*ir; dst_t * y = yy + i*QK_K + 64*il + n*ir;
const int8_t * q = x[i].qs + 64*il + n*ir; const int8_t * q = x[i].qs + 64*il + n*ir;
@ -1098,14 +1107,43 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f
#endif #endif
} }
extern "C" __global__ void dequantize_block_q5_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { template<typename dst_t>
static __device__ void dequantize_block_q5_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
return dequantize_block<QK5_0, QR5_0, dequantize_q5_0>(vx, yy, nb32); return dequantize_block<QK5_0, QR5_0, dequantize_q5_0>(vx, yy, nb32);
} }
extern "C" __global__ void dequantize_block_q5_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { template<typename dst_t>
static __device__ void dequantize_block_q5_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
return dequantize_block<QK5_1, QR5_1, dequantize_q5_1>(vx, yy, nb32); return dequantize_block<QK5_1, QR5_1, dequantize_q5_1>(vx, yy, nb32);
} }
#define DEQUANTIZE_K(QNAME) \
extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y) { \
dequantize_block_##QNAME(vx, y); \
} \
extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y) { \
dequantize_block_##QNAME(vx, y); \
} \
#define DEQUANTIZE(QNAME) \
extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y, const int k) { \
dequantize_block_##QNAME(vx, y, k); \
} \
extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) { \
dequantize_block_##QNAME(vx, y, k); \
} \
DEQUANTIZE_K(q2_K)
DEQUANTIZE_K(q3_K)
DEQUANTIZE_K(q4_K)
DEQUANTIZE_K(q5_K)
DEQUANTIZE_K(q6_K)
DEQUANTIZE_K(q8_K)
DEQUANTIZE(q4_0)
DEQUANTIZE(q4_1)
DEQUANTIZE(q5_0)
DEQUANTIZE(q5_1)
DEQUANTIZE(q8_0)
template <int qk, int qr, dequantize_kernel_t dequantize_kernel> template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {