mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Quantized cuda tweaks. (#1981)
* Quantized cuda tweaks.
* Add some safety checks.
* Factorize the dequantization bits.
This commit is contained in:
@ -1,9 +1,11 @@
|
|||||||
use super::{GgmlDType, QStorage};
|
use super::{GgmlDType, QStorage};
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
||||||
use crate::{CudaDevice, CudaStorage, Result};
|
use crate::{CudaDevice, CudaStorage, Result};
|
||||||
|
|
||||||
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
|
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
pub struct QCudaStorage {
|
pub struct QCudaStorage {
|
||||||
data: CudaSlice<u8>,
|
data: CudaSlice<u8>,
|
||||||
dtype: GgmlDType,
|
dtype: GgmlDType,
|
||||||
@ -26,6 +28,14 @@ pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
|
|||||||
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
|
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
|
||||||
pub const MATRIX_ROW_PADDING: usize = 512;
|
pub const MATRIX_ROW_PADDING: usize = 512;
|
||||||
|
|
||||||
|
/// Integer division of `p` by `q`, rounded up to the next whole number.
///
/// Computed as quotient plus a carry for a non-zero remainder, so the
/// intermediate value never exceeds `p` and cannot overflow (unlike the
/// `(p + q - 1) / q` formulation when `p` is close to `usize::MAX`).
///
/// # Panics
/// Panics if `q` is zero.
fn ceil_div(p: usize, q: usize) -> usize {
    p / q + usize::from(p % q != 0)
}
|
||||||
|
|
||||||
|
fn pad(p: usize, q: usize) -> usize {
|
||||||
|
ceil_div(p, q) * q
|
||||||
|
}
|
||||||
|
|
||||||
fn quantize_q8_1(
|
fn quantize_q8_1(
|
||||||
src: &CudaView<f32>,
|
src: &CudaView<f32>,
|
||||||
dst: &mut CudaSlice<u8>,
|
dst: &mut CudaSlice<u8>,
|
||||||
@ -35,8 +45,8 @@ fn quantize_q8_1(
|
|||||||
use cudarc::driver::LaunchAsync;
|
use cudarc::driver::LaunchAsync;
|
||||||
|
|
||||||
let kx = elem_count;
|
let kx = elem_count;
|
||||||
let kx_padded = (kx + MATRIX_ROW_PADDING - 1) / MATRIX_ROW_PADDING * MATRIX_ROW_PADDING;
|
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
|
||||||
let num_blocks = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
|
||||||
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
grid_dim: (num_blocks as u32, 1, 1),
|
grid_dim: (num_blocks as u32, 1, 1),
|
||||||
@ -60,26 +70,18 @@ fn dequantize(
|
|||||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
|
||||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
|
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
|
||||||
GgmlDType::Q5_0 => {
|
GgmlDType::Q5_0 => (
|
||||||
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
|
"dequantize_block_q5_0",
|
||||||
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
|
false,
|
||||||
(
|
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||||
"dequantize_block_q5_0",
|
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||||
false,
|
),
|
||||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
GgmlDType::Q5_1 => (
|
||||||
nb,
|
"dequantize_block_q5_1",
|
||||||
)
|
false,
|
||||||
}
|
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||||
GgmlDType::Q5_1 => {
|
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||||
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
|
),
|
||||||
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
|
|
||||||
(
|
|
||||||
"dequantize_block_q5_1",
|
|
||||||
false,
|
|
||||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
|
||||||
nb,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
|
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
|
||||||
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
|
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
|
||||||
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
|
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
|
||||||
@ -123,6 +125,13 @@ fn dequantize_mul_mat_vec(
|
|||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
use cudarc::driver::LaunchAsync;
|
||||||
|
|
||||||
|
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
|
||||||
|
if data_elems < ncols * nrows {
|
||||||
|
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||||
|
}
|
||||||
|
if y.len() != ncols {
|
||||||
|
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
|
||||||
|
}
|
||||||
let kernel_name = match dtype {
|
let kernel_name = match dtype {
|
||||||
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
|
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
|
||||||
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
|
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
|
||||||
@ -138,7 +147,7 @@ fn dequantize_mul_mat_vec(
|
|||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
|
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
|
||||||
let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
grid_dim: (block_num_y as u32, 1, 1),
|
grid_dim: (block_num_y as u32, 1, 1),
|
||||||
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
|
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
|
||||||
@ -160,8 +169,15 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
use cudarc::driver::LaunchAsync;
|
||||||
|
|
||||||
|
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
|
||||||
|
if data_elems < ncols * nrows {
|
||||||
|
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||||
|
}
|
||||||
|
if y.len() != ncols {
|
||||||
|
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
|
||||||
|
}
|
||||||
// Start by quantizing y
|
// Start by quantizing y
|
||||||
let ncols_padded = (ncols + MATRIX_ROW_PADDING - 1) / MATRIX_ROW_PADDING * MATRIX_ROW_PADDING;
|
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
||||||
let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||||
quantize_q8_1(y, &mut y_q8_1, ncols, dev)?;
|
quantize_q8_1(y, &mut y_q8_1, ncols, dev)?;
|
||||||
@ -202,7 +218,7 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
|
|
||||||
impl QCudaStorage {
|
impl QCudaStorage {
|
||||||
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
|
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
|
||||||
let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
|
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
|
||||||
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
|
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
|
||||||
Ok(QCudaStorage {
|
Ok(QCudaStorage {
|
||||||
data,
|
data,
|
||||||
@ -220,6 +236,12 @@ impl QCudaStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
|
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
|
||||||
|
fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
|
||||||
|
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
|
||||||
|
let vec = slice.to_vec();
|
||||||
|
T::to_float(&vec, dst)
|
||||||
|
}
|
||||||
|
|
||||||
let fast_kernel = matches!(
|
let fast_kernel = matches!(
|
||||||
self.dtype,
|
self.dtype,
|
||||||
GgmlDType::Q4_0
|
GgmlDType::Q4_0
|
||||||
@ -238,69 +260,25 @@ impl QCudaStorage {
|
|||||||
return dequantize(&self.data, self.dtype, elem_count, self.device());
|
return dequantize(&self.data, self.dtype, elem_count, self.device());
|
||||||
}
|
}
|
||||||
// Run the dequantization on cpu.
|
// Run the dequantization on cpu.
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
|
|
||||||
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
|
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
|
||||||
let mut out = vec![0.0; elem_count];
|
let mut out = vec![0.0; elem_count];
|
||||||
let block_len = elem_count / self.dtype.block_size();
|
let block_len = elem_count / self.dtype.block_size();
|
||||||
match self.dtype {
|
match self.dtype {
|
||||||
GgmlDType::F32 => {
|
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
|
||||||
let slice =
|
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
|
||||||
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) };
|
GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
|
||||||
out.copy_from_slice(slice)
|
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
|
||||||
}
|
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
|
||||||
GgmlDType::F16 => {
|
GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
|
||||||
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
|
GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
|
||||||
half::f16::to_float(&vec, &mut out)?;
|
GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
|
||||||
}
|
GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
|
||||||
GgmlDType::Q4_0 => {
|
GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
|
||||||
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
|
GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
|
||||||
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
|
||||||
}
|
GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
|
||||||
GgmlDType::Q4_1 => {
|
GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
|
||||||
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
|
|
||||||
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q5_0 => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
|
|
||||||
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q5_1 => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
|
|
||||||
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q8_0 => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
|
|
||||||
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q8_1 => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
|
|
||||||
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q2K => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
|
|
||||||
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q3K => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
|
|
||||||
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q4K => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
|
|
||||||
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q5K => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
|
|
||||||
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q6K => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
|
|
||||||
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q8K => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
|
|
||||||
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
self.device
|
self.device
|
||||||
@ -405,11 +383,6 @@ impl QCudaStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Copies the first `n` values of type `T` out of the raw byte `buffer`.
///
/// Each element is read from `buffer` at offset `i * size_of::<T>()` and
/// cloned into the returned vector.
///
/// # Panics
/// Panics if `buffer` holds fewer than `n * size_of::<T>()` bytes, or if
/// that byte count overflows `usize`.
fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
    let elem_size = std::mem::size_of::<T>();
    let needed = n
        .checked_mul(elem_size)
        .expect("read_to_vec: requested byte length overflows usize");
    assert!(
        needed <= buffer.len(),
        "read_to_vec: buffer holds {} bytes but {} elements of {} bytes were requested",
        buffer.len(),
        n,
        elem_size
    );
    let base = buffer.as_ptr().cast::<T>();
    (0..n)
        .map(|i| {
            // SAFETY: the assertion above guarantees that elements `0..n` lie
            // entirely inside `buffer`. A byte buffer carries no alignment
            // guarantee for `T`, hence the unaligned read. The temporary is
            // kept in `ManuallyDrop` so only the clone is ever owned/dropped.
            // NOTE(review): assumes every bit pattern in `buffer` is a valid
            // `T`, as the previous slice cast did — confirm at call sites.
            let raw = std::mem::ManuallyDrop::new(unsafe {
                std::ptr::read_unaligned(base.add(i))
            });
            T::clone(&raw)
        })
        .collect()
}
|
|
||||||
|
|
||||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||||
device: &CudaDevice,
|
device: &CudaDevice,
|
||||||
data: &[T],
|
data: &[T],
|
||||||
@ -433,7 +406,7 @@ mod test {
|
|||||||
fn cuda_quantize_q8_1() -> Result<()> {
|
fn cuda_quantize_q8_1() -> Result<()> {
|
||||||
let dev = CudaDevice::new(0)?;
|
let dev = CudaDevice::new(0)?;
|
||||||
let el = 256;
|
let el = 256;
|
||||||
let el_padded = (el + MATRIX_ROW_PADDING - 1) / MATRIX_ROW_PADDING * MATRIX_ROW_PADDING;
|
let el_padded = pad(el, MATRIX_ROW_PADDING);
|
||||||
let y_size_in_bytes =
|
let y_size_in_bytes =
|
||||||
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||||
|
Reference in New Issue
Block a user