mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the cuda dequantize f16 kernels. (#2137)
* Add the cuda dequantize f16 kernels. * Expose the cuda kernels. * Add some testing + fix. * Test the other cases too. * A few more tests. * Add an environment variable to enable the dequantize f16 + matmul behavior.
This commit is contained in:
@ -2,6 +2,7 @@ use super::{GgmlDType, QStorage};
|
|||||||
use crate::quantized::k_quants::GgmlType;
|
use crate::quantized::k_quants::GgmlType;
|
||||||
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
||||||
use crate::{CudaDevice, CudaStorage, Result};
|
use crate::{CudaDevice, CudaStorage, Result};
|
||||||
|
use half::f16;
|
||||||
|
|
||||||
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
|
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
|
||||||
|
|
||||||
@ -59,7 +60,7 @@ fn quantize_q8_1(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dequantize(
|
fn dequantize_f32(
|
||||||
data: &CudaSlice<u8>,
|
data: &CudaSlice<u8>,
|
||||||
dtype: GgmlDType,
|
dtype: GgmlDType,
|
||||||
elem_count: usize,
|
elem_count: usize,
|
||||||
@ -69,27 +70,27 @@ fn dequantize(
|
|||||||
|
|
||||||
let nb = (elem_count + 255) / 256;
|
let nb = (elem_count + 255) / 256;
|
||||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
||||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
|
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
|
||||||
GgmlDType::Q5_0 => (
|
GgmlDType::Q5_0 => (
|
||||||
"dequantize_block_q5_0",
|
"dequantize_block_q5_0_f32",
|
||||||
false,
|
false,
|
||||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||||
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||||
),
|
),
|
||||||
GgmlDType::Q5_1 => (
|
GgmlDType::Q5_1 => (
|
||||||
"dequantize_block_q5_1",
|
"dequantize_block_q5_1_f32",
|
||||||
false,
|
false,
|
||||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||||
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||||
),
|
),
|
||||||
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
|
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
|
||||||
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
|
GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
|
||||||
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
|
GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
|
||||||
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
|
GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
|
||||||
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
|
GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
|
||||||
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
|
GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
|
||||||
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
|
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
|
||||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||||
@ -116,6 +117,63 @@ fn dequantize(
|
|||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn dequantize_f16(
|
||||||
|
data: &CudaSlice<u8>,
|
||||||
|
dtype: GgmlDType,
|
||||||
|
elem_count: usize,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaStorage> {
|
||||||
|
use cudarc::driver::LaunchAsync;
|
||||||
|
|
||||||
|
let nb = (elem_count + 255) / 256;
|
||||||
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
||||||
|
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
|
||||||
|
GgmlDType::Q5_0 => (
|
||||||
|
"dequantize_block_q5_0_f16",
|
||||||
|
false,
|
||||||
|
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||||
|
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||||
|
),
|
||||||
|
GgmlDType::Q5_1 => (
|
||||||
|
"dequantize_block_q5_1_f16",
|
||||||
|
false,
|
||||||
|
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||||
|
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||||
|
),
|
||||||
|
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
|
||||||
|
GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
|
||||||
|
GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
|
||||||
|
GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
|
||||||
|
GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
|
||||||
|
GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
|
||||||
|
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
|
||||||
|
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||||
|
};
|
||||||
|
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||||
|
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
|
||||||
|
// See e.g.
|
||||||
|
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||||
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
|
grid_dim: (num_blocks as u32, 1, 1),
|
||||||
|
block_dim: (block_dim as u32, 1, 1),
|
||||||
|
shared_mem_bytes: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
if is_k {
|
||||||
|
let params = (data, &dst);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
} else {
|
||||||
|
let nb32 = match dtype {
|
||||||
|
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||||
|
_ => elem_count / 32,
|
||||||
|
};
|
||||||
|
let params = (data, &dst, nb32 as i32);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
}
|
||||||
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
fn dequantize_mul_mat_vec(
|
fn dequantize_mul_mat_vec(
|
||||||
data: &CudaSlice<u8>,
|
data: &CudaSlice<u8>,
|
||||||
y: &CudaView<f32>,
|
y: &CudaView<f32>,
|
||||||
@ -341,7 +399,7 @@ impl QCudaStorage {
|
|||||||
| GgmlDType::Q8K
|
| GgmlDType::Q8K
|
||||||
);
|
);
|
||||||
if fast_kernel {
|
if fast_kernel {
|
||||||
return dequantize(&self.data, self.dtype, elem_count, self.device());
|
return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
|
||||||
}
|
}
|
||||||
// Run the dequantization on cpu.
|
// Run the dequantization on cpu.
|
||||||
|
|
||||||
@ -369,6 +427,10 @@ impl QCudaStorage {
|
|||||||
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
|
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
|
||||||
|
dequantize_f16(&self.data, self.dtype, elem_count, self.device())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
|
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
|
||||||
// Run the quantization on cpu.
|
// Run the quantization on cpu.
|
||||||
let src = match &src.slice {
|
let src = match &src.slice {
|
||||||
|
@ -24,6 +24,10 @@ impl QCudaStorage {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
|
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
|
||||||
use k_quants::*;
|
use k_quants::*;
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
|
|
||||||
@ -360,10 +360,25 @@ impl QTensor {
|
|||||||
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
||||||
let storage = self.storage.dequantize(self.shape.elem_count())?;
|
let storage = self.storage.dequantize(self.shape.elem_count())?;
|
||||||
let none = crate::op::BackpropOp::none();
|
let none = crate::op::BackpropOp::none();
|
||||||
let is_variable = false;
|
crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
|
||||||
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
|
}
|
||||||
|
|
||||||
|
pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
|
||||||
|
// In the CUDA case, we have a specialized kernel as this can be useful for volta
|
||||||
|
// architectures. https://github.com/huggingface/candle/issues/2136
|
||||||
|
match &self.storage {
|
||||||
|
QStorage::Cuda(s) => {
|
||||||
|
let s = s.dequantize_f16(self.shape.elem_count())?;
|
||||||
|
let none = crate::op::BackpropOp::none();
|
||||||
|
crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
|
||||||
.to_device(device)
|
.to_device(device)
|
||||||
}
|
}
|
||||||
|
_ => {
|
||||||
|
let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
|
||||||
|
Ok(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn storage_size_in_bytes(&self) -> usize {
|
pub fn storage_size_in_bytes(&self) -> usize {
|
||||||
self.storage.size_in_bytes()
|
self.storage.size_in_bytes()
|
||||||
@ -378,6 +393,7 @@ impl QTensor {
|
|||||||
pub enum QMatMul {
|
pub enum QMatMul {
|
||||||
QTensor(std::sync::Arc<QTensor>),
|
QTensor(std::sync::Arc<QTensor>),
|
||||||
Tensor(Tensor),
|
Tensor(Tensor),
|
||||||
|
TensorF16(Tensor),
|
||||||
}
|
}
|
||||||
|
|
||||||
thread_local! {
|
thread_local! {
|
||||||
@ -391,6 +407,17 @@ thread_local! {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
thread_local! {
|
||||||
|
static DEQUANTIZE_ALL_F16: bool = {
|
||||||
|
match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
|
||||||
|
Ok(s) => {
|
||||||
|
!s.is_empty() && s != "0"
|
||||||
|
},
|
||||||
|
Err(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl QMatMul {
|
impl QMatMul {
|
||||||
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
|
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
|
||||||
let dequantize = match qtensor.dtype() {
|
let dequantize = match qtensor.dtype() {
|
||||||
@ -400,6 +427,9 @@ impl QMatMul {
|
|||||||
let t = if dequantize {
|
let t = if dequantize {
|
||||||
let tensor = qtensor.dequantize(&qtensor.device())?;
|
let tensor = qtensor.dequantize(&qtensor.device())?;
|
||||||
Self::Tensor(tensor)
|
Self::Tensor(tensor)
|
||||||
|
} else if DEQUANTIZE_ALL_F16.with(|b| *b) {
|
||||||
|
let tensor = qtensor.dequantize_f16(&qtensor.device())?;
|
||||||
|
Self::TensorF16(tensor)
|
||||||
} else {
|
} else {
|
||||||
Self::QTensor(qtensor)
|
Self::QTensor(qtensor)
|
||||||
};
|
};
|
||||||
@ -486,6 +516,15 @@ impl crate::Module for QMatMul {
|
|||||||
};
|
};
|
||||||
xs.matmul(&w)
|
xs.matmul(&w)
|
||||||
}
|
}
|
||||||
|
Self::TensorF16(w) => {
|
||||||
|
let in_dtype = xs.dtype();
|
||||||
|
let w = match *xs.dims() {
|
||||||
|
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
|
||||||
|
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
|
||||||
|
_ => w.t()?,
|
||||||
|
};
|
||||||
|
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ use candle_core::{
|
|||||||
quantized::{self, GgmlDType},
|
quantized::{self, GgmlDType},
|
||||||
test_device,
|
test_device,
|
||||||
test_utils::to_vec2_round,
|
test_utils::to_vec2_round,
|
||||||
Device, IndexOp, Module, Result, Tensor,
|
DType, Device, IndexOp, Module, Result, Tensor,
|
||||||
};
|
};
|
||||||
use quantized::{k_quants, GgmlType};
|
use quantized::{k_quants, GgmlType};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
@ -225,6 +225,13 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
|
|||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
|
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
dst.to_vec1::<f32>()?,
|
dst.to_vec1::<f32>()?,
|
||||||
&[
|
&[
|
||||||
@ -251,6 +258,13 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
|
|||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
round_vector(&dst.to_vec1::<f32>()?),
|
round_vector(&dst.to_vec1::<f32>()?),
|
||||||
&[
|
&[
|
||||||
@ -277,6 +291,13 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
|
|||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
round_vector(&dst.to_vec1::<f32>()?),
|
round_vector(&dst.to_vec1::<f32>()?),
|
||||||
&[
|
&[
|
||||||
@ -303,6 +324,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
|
|||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
round_vector(&dst.to_vec1::<f32>()?),
|
round_vector(&dst.to_vec1::<f32>()?),
|
||||||
&[
|
&[
|
||||||
@ -387,6 +415,13 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
|
|||||||
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
|
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
|
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
|
||||||
if error > max_error {
|
if error > max_error {
|
||||||
bail!(
|
bail!(
|
||||||
@ -404,6 +439,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
|||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let src = src.to_vec1::<f32>()?;
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let dst = dst.to_vec1::<f32>()?;
|
||||||
@ -423,6 +465,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
|||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
let dst_big = quant_big.dequantize(device)?;
|
||||||
|
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||||
|
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
let src_big = src_big.to_vec1::<f32>()?;
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||||
@ -437,6 +486,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
|||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let src = src.to_vec1::<f32>()?;
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let dst = dst.to_vec1::<f32>()?;
|
||||||
@ -456,6 +512,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
|||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
let dst_big = quant_big.dequantize(device)?;
|
||||||
|
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||||
|
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
let src_big = src_big.to_vec1::<f32>()?;
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||||
@ -470,6 +533,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
|||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let src = src.to_vec1::<f32>()?;
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let dst = dst.to_vec1::<f32>()?;
|
||||||
@ -489,6 +559,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
|||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
let dst_big = quant_big.dequantize(device)?;
|
||||||
|
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||||
|
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
let src_big = src_big.to_vec1::<f32>()?;
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||||
@ -503,6 +580,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
|||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let src = src.to_vec1::<f32>()?;
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let dst = dst.to_vec1::<f32>()?;
|
||||||
@ -522,6 +606,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
|||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
let dst_big = quant_big.dequantize(device)?;
|
||||||
|
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||||
|
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
let src_big = src_big.to_vec1::<f32>()?;
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||||
@ -536,6 +627,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
|||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let src = src.to_vec1::<f32>()?;
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let dst = dst.to_vec1::<f32>()?;
|
||||||
@ -555,6 +653,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
|||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
let dst_big = quant_big.dequantize(device)?;
|
||||||
|
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||||
|
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
let src_big = src_big.to_vec1::<f32>()?;
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||||
@ -569,6 +674,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
|
|||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
let dst = quant.dequantize(device)?;
|
let dst = quant.dequantize(device)?;
|
||||||
|
let dst_f16 = quant.dequantize_f16(device)?;
|
||||||
|
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let src = src.to_vec1::<f32>()?;
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let dst = dst.to_vec1::<f32>()?;
|
||||||
@ -588,6 +700,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
|
|||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
let dst_big = quant_big.dequantize(device)?;
|
||||||
|
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||||
|
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
let src_big = src_big.to_vec1::<f32>()?;
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||||
|
@ -765,20 +765,21 @@ static __device__ void dequantize_block(const void * __restrict__ vx, dst_t * __
|
|||||||
y[iybs + iqs + y_offset] = v.y;
|
y[iybs + iqs + y_offset] = v.y;
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
template<typename dst_t>
|
||||||
|
static __device__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
|
|
||||||
// assume 32 threads
|
// assume 32 threads
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
const int il = tid/8;
|
const int il = tid/8;
|
||||||
const int ir = tid%8;
|
const int ir = tid%8;
|
||||||
const int ib = 8*i + ir;
|
const int64_t ib = 8*i + ir;
|
||||||
if (ib >= nb32) {
|
if (ib >= nb32) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float * y = yy + 256*i + 32*ir + 4*il;
|
dst_t * y = yy + 256*i + 32*ir + 4*il;
|
||||||
|
|
||||||
const block_q4_0 * x = (const block_q4_0 *)vx + ib;
|
const block_q4_0 * x = (const block_q4_0 *)vx + ib;
|
||||||
const float d = __half2float(x->d);
|
const float d = __half2float(x->d);
|
||||||
@ -792,20 +793,21 @@ extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, f
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void dequantize_block_q4_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
template<typename dst_t>
|
||||||
|
static __device__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
|
|
||||||
// assume 32 threads
|
// assume 32 threads
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
const int il = tid/8;
|
const int il = tid/8;
|
||||||
const int ir = tid%8;
|
const int ir = tid%8;
|
||||||
const int ib = 8*i + ir;
|
const int64_t ib = 8*i + ir;
|
||||||
if (ib >= nb32) {
|
if (ib >= nb32) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float * y = yy + 256*i + 32*ir + 4*il;
|
dst_t * y = yy + 256*i + 32*ir + 4*il;
|
||||||
|
|
||||||
const block_q4_1 * x = (const block_q4_1 *)vx + ib;
|
const block_q4_1 * x = (const block_q4_1 *)vx + ib;
|
||||||
const float2 d = __half22float2(x->dm);
|
const float2 d = __half22float2(x->dm);
|
||||||
@ -820,7 +822,8 @@ extern "C" __global__ void dequantize_block_q4_1(const void * __restrict__ vx, f
|
|||||||
|
|
||||||
//================================== k-quants
|
//================================== k-quants
|
||||||
|
|
||||||
extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
template<typename dst_t>
|
||||||
|
static __device__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int i = blockIdx.x;
|
||||||
const block_q2_K * x = (const block_q2_K *) vx;
|
const block_q2_K * x = (const block_q2_K *) vx;
|
||||||
@ -832,7 +835,7 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f
|
|||||||
const int is = 8*n + l/16;
|
const int is = 8*n + l/16;
|
||||||
|
|
||||||
const uint8_t q = x[i].qs[32*n + l];
|
const uint8_t q = x[i].qs[32*n + l];
|
||||||
float * y = yy + i*QK_K + 128*n;
|
dst_t * y = yy + i*QK_K + 128*n;
|
||||||
|
|
||||||
float dall = __low2half(x[i].dm);
|
float dall = __low2half(x[i].dm);
|
||||||
float dmin = __high2half(x[i].dm);
|
float dmin = __high2half(x[i].dm);
|
||||||
@ -844,7 +847,7 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f
|
|||||||
const int is = tid/16; // 0 or 1
|
const int is = tid/16; // 0 or 1
|
||||||
const int il = tid%16; // 0...15
|
const int il = tid%16; // 0...15
|
||||||
const uint8_t q = x[i].qs[il] >> (2*is);
|
const uint8_t q = x[i].qs[il] >> (2*is);
|
||||||
float * y = yy + i*QK_K + 16*is + il;
|
dst_t * y = yy + i*QK_K + 16*is + il;
|
||||||
float dall = __low2half(x[i].dm);
|
float dall = __low2half(x[i].dm);
|
||||||
float dmin = __high2half(x[i].dm);
|
float dmin = __high2half(x[i].dm);
|
||||||
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
|
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
|
||||||
@ -853,7 +856,8 @@ extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, f
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
template<typename dst_t>
|
||||||
|
static __device__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int i = blockIdx.x;
|
||||||
const block_q3_K * x = (const block_q3_K *) vx;
|
const block_q3_K * x = (const block_q3_K *) vx;
|
||||||
@ -877,7 +881,7 @@ extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, f
|
|||||||
float d_all = x[i].d;
|
float d_all = x[i].d;
|
||||||
float dl = d_all * (us - 32);
|
float dl = d_all * (us - 32);
|
||||||
|
|
||||||
float * y = yy + i*QK_K + 128*n + 32*j;
|
dst_t * y = yy + i*QK_K + 128*n + 32*j;
|
||||||
const uint8_t * q = x[i].qs + 32*n;
|
const uint8_t * q = x[i].qs + 32*n;
|
||||||
const uint8_t * hm = x[i].hmask;
|
const uint8_t * hm = x[i].hmask;
|
||||||
|
|
||||||
@ -889,7 +893,7 @@ extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, f
|
|||||||
const int im = il/8; // 0...1
|
const int im = il/8; // 0...1
|
||||||
const int in = il%8; // 0...7
|
const int in = il%8; // 0...7
|
||||||
|
|
||||||
float * y = yy + i*QK_K + 16*is + il;
|
dst_t * y = yy + i*QK_K + 16*is + il;
|
||||||
|
|
||||||
const uint8_t q = x[i].qs[il] >> (2*is);
|
const uint8_t q = x[i].qs[il] >> (2*is);
|
||||||
const uint8_t h = x[i].hmask[in] >> (2*is + im);
|
const uint8_t h = x[i].hmask[in] >> (2*is + im);
|
||||||
@ -917,7 +921,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
template<typename dst_t>
|
||||||
|
static __device__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
const block_q4_K * x = (const block_q4_K *) vx;
|
const block_q4_K * x = (const block_q4_K *) vx;
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int i = blockIdx.x;
|
||||||
@ -930,7 +935,7 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f
|
|||||||
const int is = 2*il;
|
const int is = 2*il;
|
||||||
const int n = 4;
|
const int n = 4;
|
||||||
|
|
||||||
float * y = yy + i*QK_K + 64*il + n*ir;
|
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||||
|
|
||||||
const float dall = __low2half(x[i].dm);
|
const float dall = __low2half(x[i].dm);
|
||||||
const float dmin = __high2half(x[i].dm);
|
const float dmin = __high2half(x[i].dm);
|
||||||
@ -949,7 +954,7 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f
|
|||||||
#else
|
#else
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
const uint8_t * q = x[i].qs;
|
const uint8_t * q = x[i].qs;
|
||||||
float * y = yy + i*QK_K;
|
dst_t * y = yy + i*QK_K;
|
||||||
const float d = (float)x[i].dm[0];
|
const float d = (float)x[i].dm[0];
|
||||||
const float m = (float)x[i].dm[1];
|
const float m = (float)x[i].dm[1];
|
||||||
y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
|
y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
|
||||||
@ -957,7 +962,8 @@ extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, f
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
template<typename dst_t>
|
||||||
|
static __device__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
const block_q5_K * x = (const block_q5_K *) vx;
|
const block_q5_K * x = (const block_q5_K *) vx;
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int i = blockIdx.x;
|
||||||
@ -969,7 +975,7 @@ extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, f
|
|||||||
const int ir = tid%16; // ir is in 0...15
|
const int ir = tid%16; // ir is in 0...15
|
||||||
const int is = 2*il; // is is in 0...6
|
const int is = 2*il; // is is in 0...6
|
||||||
|
|
||||||
float * y = yy + i*QK_K + 64*il + 2*ir;
|
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
|
||||||
|
|
||||||
const float dall = __low2half(x[i].dm);
|
const float dall = __low2half(x[i].dm);
|
||||||
const float dmin = __high2half(x[i].dm);
|
const float dmin = __high2half(x[i].dm);
|
||||||
@ -997,25 +1003,26 @@ extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, f
|
|||||||
const int is = tid/16; // 0 or 1
|
const int is = tid/16; // 0 or 1
|
||||||
const uint8_t h = x[i].qh[in] >> im;
|
const uint8_t h = x[i].qh[in] >> im;
|
||||||
const float d = x[i].d;
|
const float d = x[i].d;
|
||||||
float * y = yy + i*QK_K + tid;
|
dst_t * y = yy + i*QK_K + tid;
|
||||||
y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
|
y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
|
||||||
y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
|
y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
template<typename dst_t>
|
||||||
|
static __device__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
const block_q6_K * x = (const block_q6_K *) vx;
|
const block_q6_K * x = (const block_q6_K *) vx;
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
|
|
||||||
// assume 64 threads - this is very slightly better than the one below
|
// assume 64 threads - this is very slightly better than the one below
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
const int ip = tid/32; // ip is 0 or 1
|
const int64_t ip = tid/32; // ip is 0 or 1
|
||||||
const int il = tid - 32*ip; // 0...32
|
const int64_t il = tid - 32*ip; // 0...32
|
||||||
const int is = 8*ip + il/16;
|
const int64_t is = 8*ip + il/16;
|
||||||
|
|
||||||
float * y = yy + i*QK_K + 128*ip + il;
|
dst_t * y = yy + i*QK_K + 128*ip + il;
|
||||||
|
|
||||||
const float d = x[i].d;
|
const float d = x[i].d;
|
||||||
|
|
||||||
@ -1030,11 +1037,11 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
|
|||||||
#else
|
#else
|
||||||
|
|
||||||
// assume 32 threads
|
// assume 32 threads
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
const int ip = tid/16; // 0 or 1
|
const int64_t ip = tid/16; // 0 or 1
|
||||||
const int il = tid - 16*ip; // 0...15
|
const int64_t il = tid - 16*ip; // 0...15
|
||||||
|
|
||||||
float * y = yy + i*QK_K + 16*ip + il;
|
dst_t * y = yy + i*QK_K + 16*ip + il;
|
||||||
|
|
||||||
const float d = x[i].d;
|
const float d = x[i].d;
|
||||||
|
|
||||||
@ -1047,7 +1054,8 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
template<typename dst_t>
|
||||||
|
static __device__ void dequantize_block_q8_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||||
const int i = blockIdx.x;
|
const int i = blockIdx.x;
|
||||||
|
|
||||||
// assume 32 threads
|
// assume 32 threads
|
||||||
@ -1059,7 +1067,7 @@ extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, f
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float * y = yy + 256*i + 32*ir + 8*il;
|
dst_t * y = yy + 256*i + 32*ir + 8*il;
|
||||||
|
|
||||||
const block_q8_0 * x = (const block_q8_0 *)vx + ib;
|
const block_q8_0 * x = (const block_q8_0 *)vx + ib;
|
||||||
const float d = __half2float(x->d);
|
const float d = __half2float(x->d);
|
||||||
@ -1071,7 +1079,8 @@ extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, f
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
template<typename dst_t>
|
||||||
|
static __device__ void dequantize_block_q8_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
const block_q8_K * x = (const block_q8_K *) vx;
|
const block_q8_K * x = (const block_q8_K *) vx;
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int i = blockIdx.x;
|
||||||
@ -1083,7 +1092,7 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f
|
|||||||
const int ir = tid%8;
|
const int ir = tid%8;
|
||||||
const int n = 8;
|
const int n = 8;
|
||||||
|
|
||||||
float * y = yy + i*QK_K + 64*il + n*ir;
|
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||||
|
|
||||||
const int8_t * q = x[i].qs + 64*il + n*ir;
|
const int8_t * q = x[i].qs + 64*il + n*ir;
|
||||||
|
|
||||||
@ -1098,14 +1107,43 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void dequantize_block_q5_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
template<typename dst_t>
|
||||||
|
static __device__ void dequantize_block_q5_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||||
return dequantize_block<QK5_0, QR5_0, dequantize_q5_0>(vx, yy, nb32);
|
return dequantize_block<QK5_0, QR5_0, dequantize_q5_0>(vx, yy, nb32);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void dequantize_block_q5_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
template<typename dst_t>
|
||||||
|
static __device__ void dequantize_block_q5_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||||
return dequantize_block<QK5_1, QR5_1, dequantize_q5_1>(vx, yy, nb32);
|
return dequantize_block<QK5_1, QR5_1, dequantize_q5_1>(vx, yy, nb32);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define DEQUANTIZE_K(QNAME) \
|
||||||
|
extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y) { \
|
||||||
|
dequantize_block_##QNAME(vx, y); \
|
||||||
|
} \
|
||||||
|
extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y) { \
|
||||||
|
dequantize_block_##QNAME(vx, y); \
|
||||||
|
} \
|
||||||
|
|
||||||
|
#define DEQUANTIZE(QNAME) \
|
||||||
|
extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y, const int k) { \
|
||||||
|
dequantize_block_##QNAME(vx, y, k); \
|
||||||
|
} \
|
||||||
|
extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) { \
|
||||||
|
dequantize_block_##QNAME(vx, y, k); \
|
||||||
|
} \
|
||||||
|
|
||||||
|
DEQUANTIZE_K(q2_K)
|
||||||
|
DEQUANTIZE_K(q3_K)
|
||||||
|
DEQUANTIZE_K(q4_K)
|
||||||
|
DEQUANTIZE_K(q5_K)
|
||||||
|
DEQUANTIZE_K(q6_K)
|
||||||
|
DEQUANTIZE_K(q8_K)
|
||||||
|
DEQUANTIZE(q4_0)
|
||||||
|
DEQUANTIZE(q4_1)
|
||||||
|
DEQUANTIZE(q5_0)
|
||||||
|
DEQUANTIZE(q5_1)
|
||||||
|
DEQUANTIZE(q8_0)
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
|
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||||
|
Reference in New Issue
Block a user