Add the mmv kernels for small batch sizes. (#2075)
* Add the mmv kernels for smaller sizes.
* Support more mmv kernels.
* Use the new kernels.
* Fix the call.
* Silly fix.
* Improve the testing.
* Fix for dmmv.
* Add another dedicated test for the batching mmv.
@@ -166,6 +166,7 @@ fn mul_mat_vec_via_q8_1(
     dtype: GgmlDType,
     ncols: usize,
     nrows: usize,
+    b_size: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
     use cudarc::driver::LaunchAsync;
@@ -174,14 +175,18 @@ fn mul_mat_vec_via_q8_1(
     if data_elems < ncols * nrows {
         crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
     }
-    if y.len() != ncols {
+    if y.len() != ncols * b_size {
         crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
     }
+    if b_size == 0 || b_size > 4 {
+        crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
+    }
     // Start by quantizing y
     let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
-    let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+    let y_size_in_bytes =
+        b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
     let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
-    quantize_q8_1(y, &mut y_q8_1, ncols, 1, dev)?;
+    quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;

     let kernel_name = match dtype {
         GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
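Note: the quantized y buffer is now sized per batch row. The sketch below spells that sizing out with the constants inlined; it is illustrative only, assuming Q8_1 blocks of 32 elements at 36 bytes each (f16 scale + f16 sum + 32 quants) and MATRIX_ROW_PADDING = 512, and the helper name q8_1_y_size_in_bytes is made up here, not candle API:

// Illustrative sizing helper, not the candle API. Assumes Q8_1 blocks of
// 32 elements taking 36 bytes each and MATRIX_ROW_PADDING = 512.
fn q8_1_y_size_in_bytes(ncols: usize, b_size: usize) -> usize {
    const MATRIX_ROW_PADDING: usize = 512;
    const Q8_1_BLOCK_SIZE: usize = 32; // elements per Q8_1 block
    const Q8_1_TYPE_SIZE: usize = 36; // bytes per Q8_1 block
    // pad(ncols, MATRIX_ROW_PADDING): round each row up so the kernels
    // always read whole blocks, even past the logical end of the row.
    let ncols_padded = ncols.div_ceil(MATRIX_ROW_PADDING) * MATRIX_ROW_PADDING;
    // One padded, quantized row per batch element.
    b_size * ncols_padded * Q8_1_TYPE_SIZE / Q8_1_BLOCK_SIZE
}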
@@ -196,10 +201,16 @@ fn mul_mat_vec_via_q8_1(
         GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
+    let kernel_name = format!("{kernel_name}{b_size}");
+    let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
+    let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
+    let nblocks = if b_size == 1 {
+        nrows as u32
+    } else {
+        (nrows as u32 + 1) / 2
+    };
     let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (nrows as u32, 1, 1),
+        grid_dim: (nblocks, 1, 1),
         block_dim: (WARP_SIZE as u32, 4, 1),
         shared_mem_bytes: 0,
     };
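With this change the loaded kernel carries the batch size as a name suffix (mul_mat_vec_q4_0_q8_1_cuda2 for b_size == 2, and so on), and the launch grid shrinks for batched calls. A minimal sketch of the grid rule, assuming, as the halving suggests, that the batched kernels compute up to two rows per block:

// Sketch of the grid rule above (illustrative helper, not candle API):
// b_size == 1 keeps one row per block; the batched variants (suffixes
// 2..=4) get a grid halved and rounded up, i.e. up to two rows per block.
fn grid_dim(nrows: u32, b_size: usize) -> (u32, u32, u32) {
    let nblocks = if b_size == 1 { nrows } else { (nrows + 1) / 2 };
    (nblocks, 1, 1)
}
// e.g. grid_dim(7, 4) == (4, 1, 1): four blocks cover seven rows in pairs.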
@@ -210,7 +221,7 @@ fn mul_mat_vec_via_q8_1(
         &dst,
         /* ncols_x */ ncols as i32,
         /* nrows_x */ nrows as i32,
-        /* nrows_y */ ncols as i32,
+        /* nrows_y */ ncols_padded as i32,
         /* nrows_dst */ nrows as i32,
     );
     unsafe { func.launch(cfg, params) }.w()?;
@@ -384,7 +395,17 @@ impl QCudaStorage {
         storage: &CudaStorage,
         layout: &crate::Layout,
     ) -> Result<(CudaStorage, crate::Shape)> {
-        if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
+        let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
+            1
+        } else {
+            4
+        };
+        let use_vec_kernel = match layout.shape().dims() {
+            [b, m, _k] => b * m <= max_bm,
+            [b, _k] => *b <= max_bm,
+            _ => false,
+        };
+        if use_vec_kernel {
             self.dequantize_matmul_vec(self_shape, storage, layout)
         } else {
             self.dequantize_matmul(self_shape, storage, layout)
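The dispatch now routes any effective batch of at most 4 rows to the vec kernels, with the FORCE_DMMV escape hatch pinning the cap back to 1. The same rule, isolated as a standalone predicate (an illustrative sketch, not candle API):

// Sketch of the dispatch rule above: small batches take the fused
// mat-vec path, everything else falls back to dequantize-then-matmul.
fn use_vec_kernel(dims: &[usize], force_dmmv: bool) -> bool {
    let max_bm = if force_dmmv { 1 } else { 4 };
    match dims {
        [b, m, _k] => b * m <= max_bm, // [batch, m, k] rhs
        [b, _k] => *b <= max_bm,       // [batch, k] rhs
        _ => false,                    // other ranks: general matmul
    }
}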
@@ -405,25 +426,31 @@ impl QCudaStorage {
             Some((o1, o2)) => rhs.slice(o1..o2),
             None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
         };
-        let (with_batch, k) = match rhs_l.shape().dims() {
-            [1, 1, k] => (true, k),
-            [1, k] => (false, k),
+        let (b_size, k) = match rhs_l.shape().dims() {
+            [b, m, k] => (b * m, *k),
+            [b, k] => (*b, *k),
             _ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
         };
-        if ncols != *k {
+        if ncols != k {
             crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
         }

         let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
             dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
         } else {
-            mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
-        };
-        let out_shape = if with_batch {
-            vec![1, 1, nrows]
-        } else {
-            vec![1, nrows]
+            mul_mat_vec_via_q8_1(
+                &self.data,
+                &rhs,
+                self.dtype,
+                ncols,
+                nrows,
+                b_size,
+                self.device(),
+            )?
         };
+        let mut out_shape = rhs_l.shape().dims().to_vec();
+        out_shape.pop();
+        out_shape.push(nrows);
         Ok((out, out_shape.into()))
     }
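The output shape is now derived from the rhs dimensions rather than the two hard-coded [1, 1, nrows] / [1, nrows] cases. The rule, as a small sketch (hypothetical helper name):

// Sketch of the output-shape rule above: keep every leading (batch)
// dimension of the rhs and replace the trailing k with nrows, so
// [b, m, k] -> [b, m, nrows] and [b, k] -> [b, nrows].
fn dmmv_out_shape(rhs_dims: &[usize], nrows: usize) -> Vec<usize> {
    let mut out_shape = rhs_dims.to_vec();
    out_shape.pop(); // drop k; rhs rank >= 2 is checked by the match above
    out_shape.push(nrows);
    out_shape
}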
@@ -522,6 +549,7 @@ mod test {
             /* dtype */ GgmlDType::Q4_0,
             /* ncols */ ncols,
             /* nrows */ 1,
+            /* b_size */ 1,
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
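The commit message also mentions a dedicated test for the batching mmv. When extending such tests, a plain CPU reference of what the batched kernels should produce is handy; the helper below is hypothetical (not part of the candle test suite) and assumes batch-major output, matching the [b, nrows] output shape above:

// Hypothetical CPU reference, not in candle: the same weight matrix `a`
// (row-major, nrows x ncols, already dequantized to f32) is applied to
// each of the b_size rhs rows, producing out[b * nrows + r].
fn mmv_reference(a: &[f32], y: &[f32], ncols: usize, nrows: usize, b_size: usize) -> Vec<f32> {
    assert_eq!(a.len(), ncols * nrows);
    assert_eq!(y.len(), ncols * b_size);
    let mut out = vec![0f32; nrows * b_size];
    for b in 0..b_size {
        for r in 0..nrows {
            let row = &a[r * ncols..(r + 1) * ncols];
            let rhs = &y[b * ncols..(b + 1) * ncols];
            // out[b][r] = sum_k a[r][k] * y[b][k]
            out[b * nrows + r] = row.iter().zip(rhs).map(|(w, x)| w * x).sum();
        }
    }
    out
}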