mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the mmv kernels for small batch sizes. (#2075)
* Add the mmv kernels for smaller sizes. * Support more mmv kernels. * Use the new kernels. * Fix the call. * Silly fix. * Improve the testing. * Fix for dmmv. * Add another dedicated test for the batching mmv.
This commit is contained in:
@ -166,6 +166,7 @@ fn mul_mat_vec_via_q8_1(
|
||||
dtype: GgmlDType,
|
||||
ncols: usize,
|
||||
nrows: usize,
|
||||
b_size: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
use cudarc::driver::LaunchAsync;
|
||||
@ -174,14 +175,18 @@ fn mul_mat_vec_via_q8_1(
|
||||
if data_elems < ncols * nrows {
|
||||
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||
}
|
||||
if y.len() != ncols {
|
||||
if y.len() != ncols * b_size {
|
||||
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
|
||||
}
|
||||
if b_size == 0 || b_size > 4 {
|
||||
crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
|
||||
}
|
||||
// Start by quantizing y
|
||||
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
||||
let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||
let y_size_in_bytes =
|
||||
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||
quantize_q8_1(y, &mut y_q8_1, ncols, 1, dev)?;
|
||||
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
|
||||
|
||||
let kernel_name = match dtype {
|
||||
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
|
||||
@ -196,10 +201,16 @@ fn mul_mat_vec_via_q8_1(
|
||||
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
|
||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
|
||||
let kernel_name = format!("{kernel_name}{b_size}");
|
||||
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
|
||||
let nblocks = if b_size == 1 {
|
||||
nrows as u32
|
||||
} else {
|
||||
(nrows as u32 + 1) / 2
|
||||
};
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (nrows as u32, 1, 1),
|
||||
grid_dim: (nblocks, 1, 1),
|
||||
block_dim: (WARP_SIZE as u32, 4, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
@ -210,7 +221,7 @@ fn mul_mat_vec_via_q8_1(
|
||||
&dst,
|
||||
/* ncols_x */ ncols as i32,
|
||||
/* nrows_x */ nrows as i32,
|
||||
/* nrows_y */ ncols as i32,
|
||||
/* nrows_y */ ncols_padded as i32,
|
||||
/* nrows_dst */ nrows as i32,
|
||||
);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
@ -384,7 +395,17 @@ impl QCudaStorage {
|
||||
storage: &CudaStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(CudaStorage, crate::Shape)> {
|
||||
if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
|
||||
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
1
|
||||
} else {
|
||||
4
|
||||
};
|
||||
let use_vec_kernel = match layout.shape().dims() {
|
||||
[b, m, _k] => b * m <= max_bm,
|
||||
[b, _k] => *b <= max_bm,
|
||||
_ => false,
|
||||
};
|
||||
if use_vec_kernel {
|
||||
self.dequantize_matmul_vec(self_shape, storage, layout)
|
||||
} else {
|
||||
self.dequantize_matmul(self_shape, storage, layout)
|
||||
@ -405,25 +426,31 @@ impl QCudaStorage {
|
||||
Some((o1, o2)) => rhs.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
|
||||
};
|
||||
let (with_batch, k) = match rhs_l.shape().dims() {
|
||||
[1, 1, k] => (true, k),
|
||||
[1, k] => (false, k),
|
||||
let (b_size, k) = match rhs_l.shape().dims() {
|
||||
[b, m, k] => (b * m, *k),
|
||||
[b, k] => (*b, *k),
|
||||
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
|
||||
};
|
||||
if ncols != *k {
|
||||
if ncols != k {
|
||||
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
|
||||
}
|
||||
|
||||
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
|
||||
} else {
|
||||
mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
|
||||
};
|
||||
let out_shape = if with_batch {
|
||||
vec![1, 1, nrows]
|
||||
} else {
|
||||
vec![1, nrows]
|
||||
mul_mat_vec_via_q8_1(
|
||||
&self.data,
|
||||
&rhs,
|
||||
self.dtype,
|
||||
ncols,
|
||||
nrows,
|
||||
b_size,
|
||||
self.device(),
|
||||
)?
|
||||
};
|
||||
let mut out_shape = rhs_l.shape().dims().to_vec();
|
||||
out_shape.pop();
|
||||
out_shape.push(nrows);
|
||||
Ok((out, out_shape.into()))
|
||||
}
|
||||
|
||||
@ -522,6 +549,7 @@ mod test {
|
||||
/* dtype */ GgmlDType::Q4_0,
|
||||
/* ncols */ ncols,
|
||||
/* nrows */ 1,
|
||||
/* b_size */ 1,
|
||||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
|
@ -170,12 +170,46 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
||||
let res2 = matmul.forward(&lhs2)?;
|
||||
let res2 = res2.i(1)?;
|
||||
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
if device.is_cuda() {
|
||||
assert!(diff < 0.1);
|
||||
} else {
|
||||
assert_eq!(diff, 0.);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn qmm_batch(dev: &Device) -> Result<()> {
|
||||
let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
assert_eq!(mm.shape().dims(), [2, 6]);
|
||||
let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
|
||||
let mm2 = rhs.forward(&lhs2)?;
|
||||
assert_eq!(mm2.shape().dims(), [4, 6]);
|
||||
let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff2, 0.0);
|
||||
let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
|
||||
let mm3 = rhs.forward(&lhs3)?;
|
||||
assert_eq!(mm3.shape().dims(), [6, 6]);
|
||||
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
if dev.is_cuda() {
|
||||
assert!(diff3 < 1e-4)
|
||||
} else {
|
||||
assert_eq!(diff3, 0.0)
|
||||
};
|
||||
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
if dev.is_cuda() {
|
||||
assert!(diff3 < 1e-4)
|
||||
} else {
|
||||
assert_eq!(diff3, 0.0)
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
|
||||
test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
|
||||
test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);
|
||||
|
||||
fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
|
@ -2648,7 +2648,8 @@ static __device__ void mul_mat_vec_q(
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda(
|
||||
// batch size = 1
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda1(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
@ -2656,7 +2657,7 @@ extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda(
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda(
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda1(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
@ -2664,7 +2665,7 @@ extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda(
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda(
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda1(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
@ -2672,7 +2673,7 @@ extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda(
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda(
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda1(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
@ -2680,7 +2681,7 @@ extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda(
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda(
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda1(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
@ -2688,7 +2689,7 @@ extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda(
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda(
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda1(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
@ -2696,7 +2697,7 @@ extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda(
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda(
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda1(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
@ -2704,7 +2705,7 @@ extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda(
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda(
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda1(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
@ -2712,7 +2713,7 @@ extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda(
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda(
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda1(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
@ -2720,7 +2721,7 @@ extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda(
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda(
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda1(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
@ -2728,6 +2729,249 @@ extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda(
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
// batch size = 2
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda2(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<2, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda2(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<2, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda2(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<2, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda2(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<2, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda2(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<2, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda2(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<2, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda2(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<2, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda2(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<2, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda2(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<2, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda2(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<2, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
// batch size = 3
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda3(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<3, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda3(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<3, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda3(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<3, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda3(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<3, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda3(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<3, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda3(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<3, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda3(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<3, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda3(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<3, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda3(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<3, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda3(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<3, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
// batch size = 4
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda4(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<4, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda4(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<4, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda4(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<4, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda4(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<4, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda4(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<4, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda4(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<4, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda4(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<4, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda4(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<4, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda4(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<4, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda4(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<4, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
|
||||
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
|
Reference in New Issue
Block a user