More ggml cuda kernels (#1977)

* Add more cuda kernels for quantized matmul.

* Add the vec-dot bits.

* Expose the quantized matmul-vec kernels.

* Also include the quantize-q8-1 kernel.

* Glue code for the q8-1 quantization.

* mm-vec product via q8-1 quantization.

* Add a test.

* Add a mm test.

* Get the test to return some sensible results.

* Also test dmmv.

* Fix the launch params.

* Allow for tweaking the force_dmmv parameter while it's experimental.
This commit is contained in:
Laurent Mazare
2024-04-01 00:15:48 +02:00
committed by GitHub
parent f9954b73ba
commit cd29c7ccd4
3 changed files with 1169 additions and 82 deletions

View File

@ -2,7 +2,7 @@ use super::{GgmlDType, QStorage};
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use cudarc::driver::{CudaSlice, DeviceSlice};
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
pub struct QCudaStorage {
data: CudaSlice<u8>,
@ -10,13 +10,43 @@ pub struct QCudaStorage {
device: CudaDevice,
}
static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(true);
pub fn set_force_dmmv(f: bool) {
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
}
pub const WARP_SIZE: usize = 32;
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
pub const MATRIX_ROW_PADDING: usize = 512;
fn quantize_q8_1(
src: &CudaView<f32>,
dst: &mut CudaSlice<u8>,
elem_count: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;
let kx = elem_count;
let kx_padded = (kx + MATRIX_ROW_PADDING - 1) / MATRIX_ROW_PADDING * MATRIX_ROW_PADDING;
let num_blocks = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
let params = (src, dst, kx as i32, kx_padded as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
fn dequantize(
data: &CudaSlice<u8>,
@ -60,7 +90,7 @@ fn dequantize(
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@ -83,9 +113,9 @@ fn dequantize(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mut_mal_vec(
fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>,
y: &cudarc::driver::CudaView<f32>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
@ -107,7 +137,7 @@ fn dequantize_mut_mal_vec(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(nrows).w()?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
@ -120,6 +150,56 @@ fn dequantize_mut_mal_vec(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn mul_mat_vec_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
// Start by quantizing y
let ncols_padded = (ncols + MATRIX_ROW_PADDING - 1) / MATRIX_ROW_PADDING * MATRIX_ROW_PADDING;
let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, ncols, dev)?;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nrows as u32, 1, 1),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};
let params = (
data,
&y_q8_1,
&dst,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
/* nrows_y */ ncols as i32,
/* nrows_dst */ nrows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
@ -285,8 +365,11 @@ impl QCudaStorage {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
let out =
dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
} else {
mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
};
let out_shape = if with_batch {
vec![1, 1, nrows]
} else {
@ -341,3 +424,60 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
dtype: T::DTYPE,
}))
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn cuda_quantize_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let el = 256;
let el_padded = (el + MATRIX_ROW_PADDING - 1) / MATRIX_ROW_PADDING * MATRIX_ROW_PADDING;
let y_size_in_bytes =
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?;
Ok(())
}
#[test]
fn cuda_mmv_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
// Q8 means 1/256 precision.
assert_eq!(vs[0], 5561664.5);
let cuda_storage = dequantize_mul_mat_vec(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
assert_eq!(vs[0], 5561851.0);
Ok(())
}
}

View File

@ -235,6 +235,10 @@ struct Args {
/// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
#[arg(long)]
gqa: Option<usize>,
/// Use the (experimental) fast cuda kernels.
#[arg(long)]
fast_cuda: bool,
}
impl Args {
@ -341,6 +345,10 @@ fn main() -> anyhow::Result<()> {
use tracing_subscriber::prelude::*;
let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(!args.fast_cuda);
let temperature = if args.temperature == 0. {
None
} else {

File diff suppressed because it is too large Load Diff