Mirror of https://github.com/huggingface/candle.git
Add the cuda dequantize f16 kernels. (#2137)
* Add the cuda dequantize f16 kernels.
* Expose the cuda kernels.
* Add some testing + fix.
* Test the other cases too.
* A few more tests.
* Add an environment variable to enable the dequantize f16 + matmul behavior.
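A minimal usage sketch, not part of the patch: the new CANDLE_DEQUANTIZE_ALL_F16 flag is read lazily into a per-thread thread_local (see the thread_local hunk below), so it has to be set before the first QMatMul is built on that thread, and any non-empty value other than "0" enables the f16 dequantize + matmul path. The helper name below is hypothetical.

use candle_core::quantized::{QMatMul, QTensor};
use candle_core::Result;
use std::sync::Arc;

// Hypothetical helper: enable the f16 path for QMatMul instances built afterwards.
// Tensors whose GGML dtype is already F32/F16 still take the plain f32 dequantize
// branch first, regardless of the flag.
fn build_f16_matmul(qt: Arc<QTensor>) -> Result<QMatMul> {
    // Must run before QMatMul::from_arc reads DEQUANTIZE_ALL_F16 on this thread.
    std::env::set_var("CANDLE_DEQUANTIZE_ALL_F16", "1");
    QMatMul::from_arc(qt) // yields QMatMul::TensorF16 on the f16 path
}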
@@ -2,6 +2,7 @@ use super::{GgmlDType, QStorage};
 use crate::quantized::k_quants::GgmlType;
 use crate::{backend::BackendDevice, cuda_backend::WrapErr};
 use crate::{CudaDevice, CudaStorage, Result};
+use half::f16;
 
 use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
 
@@ -59,7 +60,7 @@ fn quantize_q8_1(
     Ok(())
 }
 
-fn dequantize(
+fn dequantize_f32(
     data: &CudaSlice<u8>,
     dtype: GgmlDType,
     elem_count: usize,
@@ -69,27 +70,27 @@ fn dequantize(
 
     let nb = (elem_count + 255) / 256;
     let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
-        GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
-        GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
+        GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
+        GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
         GgmlDType::Q5_0 => (
-            "dequantize_block_q5_0",
+            "dequantize_block_q5_0_f32",
             false,
             CUDA_DEQUANTIZE_BLOCK_SIZE,
             ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
         ),
         GgmlDType::Q5_1 => (
-            "dequantize_block_q5_1",
+            "dequantize_block_q5_1_f32",
             false,
             CUDA_DEQUANTIZE_BLOCK_SIZE,
             ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
         ),
-        GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
-        GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
-        GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
-        GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
-        GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
-        GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
-        GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
+        GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
+        GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
+        GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
+        GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
+        GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
+        GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
+        GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
         _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
     let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
 
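For reference, the ceil_div used in the Q5 launch sizing above is the usual rounding-up integer division, and nb = (elem_count + 255) / 256 is the same computation specialized to 256-element chunks. A minimal sketch of the assumed helper:

// Rounding-up division: number of q-sized chunks needed to cover p items.
fn ceil_div(p: usize, q: usize) -> usize {
    (p + q - 1) / q
}

// e.g. ceil_div(1000, 256) == 4, matching nb = (1000 + 255) / 256.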
@@ -116,6 +117,63 @@ fn dequantize(
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
 
+fn dequantize_f16(
+    data: &CudaSlice<u8>,
+    dtype: GgmlDType,
+    elem_count: usize,
+    dev: &CudaDevice,
+) -> Result<CudaStorage> {
+    use cudarc::driver::LaunchAsync;
+
+    let nb = (elem_count + 255) / 256;
+    let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
+        GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
+        GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
+        GgmlDType::Q5_0 => (
+            "dequantize_block_q5_0_f16",
+            false,
+            CUDA_DEQUANTIZE_BLOCK_SIZE,
+            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
+        ),
+        GgmlDType::Q5_1 => (
+            "dequantize_block_q5_1_f16",
+            false,
+            CUDA_DEQUANTIZE_BLOCK_SIZE,
+            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
+        ),
+        GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
+        GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
+        GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
+        GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
+        GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
+        GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
+        GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
+        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
+    };
+    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
+    // See e.g.
+    // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
+    let cfg = cudarc::driver::LaunchConfig {
+        grid_dim: (num_blocks as u32, 1, 1),
+        block_dim: (block_dim as u32, 1, 1),
+        shared_mem_bytes: 0,
+    };
+
+    if is_k {
+        let params = (data, &dst);
+        unsafe { func.launch(cfg, params) }.w()?;
+    } else {
+        let nb32 = match dtype {
+            GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
+            _ => elem_count / 32,
+        };
+        let params = (data, &dst, nb32 as i32);
+        unsafe { func.launch(cfg, params) }.w()?;
+    }
+    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
+}
+
 fn dequantize_mul_mat_vec(
     data: &CudaSlice<u8>,
     y: &CudaView<f32>,
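The f16 variant mirrors dequantize_f32 line for line; the functional differences are the *_f16 kernel names and the element type of dst, which halves the size of the intermediate buffer. An illustrative calculation with an assumed 4096 x 4096 weight (shapes are not from the patch; requires the half crate, which candle already depends on):

fn main() {
    let elem_count: usize = 4096 * 4096;
    let f32_bytes = elem_count * std::mem::size_of::<f32>(); // 67_108_864 bytes (64 MiB)
    let f16_bytes = elem_count * std::mem::size_of::<half::f16>(); // 33_554_432 bytes (32 MiB)
    assert_eq!(f32_bytes, 2 * f16_bytes);
}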
@@ -341,7 +399,7 @@ impl QCudaStorage {
                 | GgmlDType::Q8K
         );
         if fast_kernel {
-            return dequantize(&self.data, self.dtype, elem_count, self.device());
+            return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
         }
         // Run the dequantization on cpu.
 
@@ -369,6 +427,10 @@ impl QCudaStorage {
             .storage_from_cpu_storage(&crate::CpuStorage::F32(out))
     }
 
+    pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
+        dequantize_f16(&self.data, self.dtype, elem_count, self.device())
+    }
+
     pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
         // Run the quantization on cpu.
         let src = match &src.slice {
@@ -24,6 +24,10 @@ impl QCudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@@ -1,4 +1,4 @@
-use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
+use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
 use k_quants::*;
 use std::borrow::Cow;
 
@@ -360,9 +360,24 @@ impl QTensor {
     pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
         let storage = self.storage.dequantize(self.shape.elem_count())?;
         let none = crate::op::BackpropOp::none();
-        let is_variable = false;
-        crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
-            .to_device(device)
+        crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
     }
 
+    pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
+        // In the CUDA case, we have a specialized kernel as this can be useful for volta
+        // architectures. https://github.com/huggingface/candle/issues/2136
+        match &self.storage {
+            QStorage::Cuda(s) => {
+                let s = s.dequantize_f16(self.shape.elem_count())?;
+                let none = crate::op::BackpropOp::none();
+                crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
+                    .to_device(device)
+            }
+            _ => {
+                let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
+                Ok(s)
+            }
+        }
+    }
+
     pub fn storage_size_in_bytes(&self) -> usize {
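A hedged usage sketch of the new method: on CUDA storage it runs the dedicated *_f16 kernels, on any other backend it falls back to the f32 dequantize followed by a dtype conversion, as the match above shows. The function name below is hypothetical.

use candle_core::quantized::QTensor;
use candle_core::{Device, Result, Tensor};

// Hypothetical call site: turn an already-loaded quantized weight into an f16 Tensor.
fn load_f16_weight(qt: &QTensor, dev: &Device) -> Result<Tensor> {
    qt.dequantize_f16(dev)
}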
@@ -378,6 +393,7 @@ impl QTensor {
 pub enum QMatMul {
     QTensor(std::sync::Arc<QTensor>),
     Tensor(Tensor),
+    TensorF16(Tensor),
 }
 
 thread_local! {
@@ -391,6 +407,17 @@ thread_local! {
     }
 }
 
+thread_local! {
+    static DEQUANTIZE_ALL_F16: bool = {
+        match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
 impl QMatMul {
     pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
         let dequantize = match qtensor.dtype() {
@@ -400,6 +427,9 @@ impl QMatMul {
         let t = if dequantize {
            let tensor = qtensor.dequantize(&qtensor.device())?;
            Self::Tensor(tensor)
+        } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
+            let tensor = qtensor.dequantize_f16(&qtensor.device())?;
+            Self::TensorF16(tensor)
         } else {
             Self::QTensor(qtensor)
         };
@@ -486,6 +516,15 @@ impl crate::Module for QMatMul {
                 };
                 xs.matmul(&w)
             }
+            Self::TensorF16(w) => {
+                let in_dtype = xs.dtype();
+                let w = match *xs.dims() {
+                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+                    _ => w.t()?,
+                };
+                xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
+            }
         }
     }
 }
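The TensorF16 branch follows the existing Tensor branch exactly, except the input is cast to f16 for the matmul and the result is cast back to the caller's dtype. A standalone sketch of the 2D case using plain candle_core tensor ops (batched broadcast handling omitted; the function name is hypothetical):

use candle_core::{DType, Result, Tensor};

// Hedged sketch: w_f16 is the dequantized f16 weight of shape (out_features, in_features).
fn forward_f16(xs: &Tensor, w_f16: &Tensor) -> Result<Tensor> {
    let in_dtype = xs.dtype();
    let w_t = w_f16.t()?; // transpose at use time, as in the branch above
    xs.to_dtype(DType::F16)?.matmul(&w_t)?.to_dtype(in_dtype)
}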