Add the mmv kernels for small batch sizes. (#2075)

* Add the mmv kernels for smaller sizes.

* Support more mmv kernels.

* Use the new kernels.

* Fix the call.

* Silly fix.

* Improve the testing.

* Fix for dmmv.

* Add another dedicated test for the batching mmv.
This commit is contained in:
Laurent Mazare
2024-04-16 21:30:51 +02:00
committed by GitHub
parent 4d14777673
commit 2817643db9
3 changed files with 335 additions and 29 deletions

View File

@ -166,6 +166,7 @@ fn mul_mat_vec_via_q8_1(
dtype: GgmlDType, dtype: GgmlDType,
ncols: usize, ncols: usize,
nrows: usize, nrows: usize,
b_size: usize,
dev: &CudaDevice, dev: &CudaDevice,
) -> Result<CudaStorage> { ) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync; use cudarc::driver::LaunchAsync;
@ -174,14 +175,18 @@ fn mul_mat_vec_via_q8_1(
if data_elems < ncols * nrows { if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems) crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
} }
if y.len() != ncols { if y.len() != ncols * b_size {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len()) crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
} }
if b_size == 0 || b_size > 4 {
crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
}
// Start by quantizing y // Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING); let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size(); let y_size_in_bytes =
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? }; let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, ncols, 1, dev)?; quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
let kernel_name = match dtype { let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda", GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
@ -196,10 +201,16 @@ fn mul_mat_vec_via_q8_1(
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda", GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"), _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
}; };
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; let kernel_name = format!("{kernel_name}{b_size}");
let dst = unsafe { dev.alloc::<f32>(nrows).w()? }; let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
let nblocks = if b_size == 1 {
nrows as u32
} else {
(nrows as u32 + 1) / 2
};
let cfg = cudarc::driver::LaunchConfig { let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nrows as u32, 1, 1), grid_dim: (nblocks, 1, 1),
block_dim: (WARP_SIZE as u32, 4, 1), block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0, shared_mem_bytes: 0,
}; };
@ -210,7 +221,7 @@ fn mul_mat_vec_via_q8_1(
&dst, &dst,
/* ncols_x */ ncols as i32, /* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32, /* nrows_x */ nrows as i32,
/* nrows_y */ ncols as i32, /* nrows_y */ ncols_padded as i32,
/* nrows_dst */ nrows as i32, /* nrows_dst */ nrows as i32,
); );
unsafe { func.launch(cfg, params) }.w()?; unsafe { func.launch(cfg, params) }.w()?;
@ -384,7 +395,17 @@ impl QCudaStorage {
storage: &CudaStorage, storage: &CudaStorage,
layout: &crate::Layout, layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> { ) -> Result<(CudaStorage, crate::Shape)> {
if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) { let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
1
} else {
4
};
let use_vec_kernel = match layout.shape().dims() {
[b, m, _k] => b * m <= max_bm,
[b, _k] => *b <= max_bm,
_ => false,
};
if use_vec_kernel {
self.dequantize_matmul_vec(self_shape, storage, layout) self.dequantize_matmul_vec(self_shape, storage, layout)
} else { } else {
self.dequantize_matmul(self_shape, storage, layout) self.dequantize_matmul(self_shape, storage, layout)
@ -405,25 +426,31 @@ impl QCudaStorage {
Some((o1, o2)) => rhs.slice(o1..o2), Some((o1, o2)) => rhs.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?, None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
}; };
let (with_batch, k) = match rhs_l.shape().dims() { let (b_size, k) = match rhs_l.shape().dims() {
[1, 1, k] => (true, k), [b, m, k] => (b * m, *k),
[1, k] => (false, k), [b, k] => (*b, *k),
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()), _ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
}; };
if ncols != *k { if ncols != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape()) crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
} }
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) { let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())? dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
} else { } else {
mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())? mul_mat_vec_via_q8_1(
}; &self.data,
let out_shape = if with_batch { &rhs,
vec![1, 1, nrows] self.dtype,
} else { ncols,
vec![1, nrows] nrows,
b_size,
self.device(),
)?
}; };
let mut out_shape = rhs_l.shape().dims().to_vec();
out_shape.pop();
out_shape.push(nrows);
Ok((out, out_shape.into())) Ok((out, out_shape.into()))
} }
@ -522,6 +549,7 @@ mod test {
/* dtype */ GgmlDType::Q4_0, /* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols, /* ncols */ ncols,
/* nrows */ 1, /* nrows */ 1,
/* b_size */ 1,
&dev, &dev,
)?; )?;
let vs = cuda_storage.as_cuda_slice::<f32>()?; let vs = cuda_storage.as_cuda_slice::<f32>()?;

View File

@ -170,12 +170,46 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
let res2 = matmul.forward(&lhs2)?; let res2 = matmul.forward(&lhs2)?;
let res2 = res2.i(1)?; let res2 = res2.i(1)?;
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?; let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.); if device.is_cuda() {
assert!(diff < 0.1);
} else {
assert_eq!(diff, 0.);
}
Ok(())
}
fn qmm_batch(dev: &Device) -> Result<()> {
let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.shape().dims(), [2, 6]);
let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
let mm2 = rhs.forward(&lhs2)?;
assert_eq!(mm2.shape().dims(), [4, 6]);
let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff2, 0.0);
let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
let mm3 = rhs.forward(&lhs3)?;
assert_eq!(mm3.shape().dims(), [6, 6]);
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
assert!(diff3 < 1e-4)
} else {
assert_eq!(diff3, 0.0)
};
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
assert!(diff3 < 1e-4)
} else {
assert_eq!(diff3, 0.0)
};
Ok(()) Ok(())
} }
test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal); test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal); test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);
fn quantize_q4_0(device: &Device) -> Result<()> { fn quantize_q4_0(device: &Device) -> Result<()> {
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();

View File

@ -2648,7 +2648,8 @@ static __device__ void mul_mat_vec_q(
} }
} }
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda( // batch size = 1
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda1(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@ -2656,7 +2657,7 @@ extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
} }
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda( extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda1(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@ -2664,7 +2665,7 @@ extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
} }
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda( extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda1(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@ -2672,7 +2673,7 @@ extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
} }
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda( extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda1(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@ -2680,7 +2681,7 @@ extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
} }
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda( extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda1(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@ -2688,7 +2689,7 @@ extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
} }
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda( extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda1(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@ -2696,7 +2697,7 @@ extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
} }
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda( extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda1(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@ -2704,7 +2705,7 @@ extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
} }
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda( extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda1(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@ -2712,7 +2713,7 @@ extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
} }
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda( extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda1(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@ -2720,7 +2721,7 @@ extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
} }
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda( extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda1(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@ -2728,6 +2729,249 @@ extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
} }
// batch size = 2
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda2(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<2, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda2(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<2, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda2(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<2, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda2(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<2, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda2(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<2, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda2(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<2, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda2(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<2, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda2(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<2, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda2(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<2, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda2(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<2, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
// batch size = 3
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda3(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<3, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda3(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<3, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda3(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<3, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda3(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<3, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda3(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<3, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda3(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<3, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda3(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<3, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda3(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<3, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda3(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<3, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda3(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<3, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
// batch size = 4
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda4(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<4, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda4(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<4, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda4(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<4, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda4(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<4, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda4(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<4, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda4(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<4, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda4(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<4, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda4(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<4, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda4(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<4, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda4(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
mul_mat_vec_q<4, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
const int ix = blockDim.x*blockIdx.x + threadIdx.x; const int ix = blockDim.x*blockIdx.x + threadIdx.x;