mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add the mmv kernels for small batch sizes. (#2075)
* Add the mmv kernels for smaller sizes.
* Support more mmv kernels.
* Use the new kernels.
* Fix the call.
* Silly fix.
* Improve the testing.
* Fix for dmmv.
* Add another dedicated test for the batching mmv.
This commit is contained in:
@ -166,6 +166,7 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
dtype: GgmlDType,
|
dtype: GgmlDType,
|
||||||
ncols: usize,
|
ncols: usize,
|
||||||
nrows: usize,
|
nrows: usize,
|
||||||
|
b_size: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
use cudarc::driver::LaunchAsync;
|
||||||
@ -174,14 +175,18 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
if data_elems < ncols * nrows {
|
if data_elems < ncols * nrows {
|
||||||
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||||
}
|
}
|
||||||
if y.len() != ncols {
|
if y.len() != ncols * b_size {
|
||||||
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
|
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
|
||||||
}
|
}
|
||||||
|
if b_size == 0 || b_size > 4 {
|
||||||
|
crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
|
||||||
|
}
|
||||||
// Start by quantizing y
|
// Start by quantizing y
|
||||||
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
||||||
let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
let y_size_in_bytes =
|
||||||
|
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||||
quantize_q8_1(y, &mut y_q8_1, ncols, 1, dev)?;
|
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
|
||||||
|
|
||||||
let kernel_name = match dtype {
|
let kernel_name = match dtype {
|
||||||
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
|
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
|
||||||
@ -196,10 +201,16 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
|
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
|
||||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
let kernel_name = format!("{kernel_name}{b_size}");
|
||||||
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
|
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
|
||||||
|
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
|
||||||
|
let nblocks = if b_size == 1 {
|
||||||
|
nrows as u32
|
||||||
|
} else {
|
||||||
|
(nrows as u32 + 1) / 2
|
||||||
|
};
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
grid_dim: (nrows as u32, 1, 1),
|
grid_dim: (nblocks, 1, 1),
|
||||||
block_dim: (WARP_SIZE as u32, 4, 1),
|
block_dim: (WARP_SIZE as u32, 4, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
@ -210,7 +221,7 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
&dst,
|
&dst,
|
||||||
/* ncols_x */ ncols as i32,
|
/* ncols_x */ ncols as i32,
|
||||||
/* nrows_x */ nrows as i32,
|
/* nrows_x */ nrows as i32,
|
||||||
/* nrows_y */ ncols as i32,
|
/* nrows_y */ ncols_padded as i32,
|
||||||
/* nrows_dst */ nrows as i32,
|
/* nrows_dst */ nrows as i32,
|
||||||
);
|
);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
@ -384,7 +395,17 @@ impl QCudaStorage {
|
|||||||
storage: &CudaStorage,
|
storage: &CudaStorage,
|
||||||
layout: &crate::Layout,
|
layout: &crate::Layout,
|
||||||
) -> Result<(CudaStorage, crate::Shape)> {
|
) -> Result<(CudaStorage, crate::Shape)> {
|
||||||
if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
|
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
|
1
|
||||||
|
} else {
|
||||||
|
4
|
||||||
|
};
|
||||||
|
let use_vec_kernel = match layout.shape().dims() {
|
||||||
|
[b, m, _k] => b * m <= max_bm,
|
||||||
|
[b, _k] => *b <= max_bm,
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
if use_vec_kernel {
|
||||||
self.dequantize_matmul_vec(self_shape, storage, layout)
|
self.dequantize_matmul_vec(self_shape, storage, layout)
|
||||||
} else {
|
} else {
|
||||||
self.dequantize_matmul(self_shape, storage, layout)
|
self.dequantize_matmul(self_shape, storage, layout)
|
||||||
@ -405,25 +426,31 @@ impl QCudaStorage {
|
|||||||
Some((o1, o2)) => rhs.slice(o1..o2),
|
Some((o1, o2)) => rhs.slice(o1..o2),
|
||||||
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
|
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
|
||||||
};
|
};
|
||||||
let (with_batch, k) = match rhs_l.shape().dims() {
|
let (b_size, k) = match rhs_l.shape().dims() {
|
||||||
[1, 1, k] => (true, k),
|
[b, m, k] => (b * m, *k),
|
||||||
[1, k] => (false, k),
|
[b, k] => (*b, *k),
|
||||||
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
|
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
|
||||||
};
|
};
|
||||||
if ncols != *k {
|
if ncols != k {
|
||||||
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
|
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
|
||||||
}
|
}
|
||||||
|
|
||||||
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
|
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
|
||||||
} else {
|
} else {
|
||||||
mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
|
mul_mat_vec_via_q8_1(
|
||||||
};
|
&self.data,
|
||||||
let out_shape = if with_batch {
|
&rhs,
|
||||||
vec![1, 1, nrows]
|
self.dtype,
|
||||||
} else {
|
ncols,
|
||||||
vec![1, nrows]
|
nrows,
|
||||||
|
b_size,
|
||||||
|
self.device(),
|
||||||
|
)?
|
||||||
};
|
};
|
||||||
|
let mut out_shape = rhs_l.shape().dims().to_vec();
|
||||||
|
out_shape.pop();
|
||||||
|
out_shape.push(nrows);
|
||||||
Ok((out, out_shape.into()))
|
Ok((out, out_shape.into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -522,6 +549,7 @@ mod test {
|
|||||||
/* dtype */ GgmlDType::Q4_0,
|
/* dtype */ GgmlDType::Q4_0,
|
||||||
/* ncols */ ncols,
|
/* ncols */ ncols,
|
||||||
/* nrows */ 1,
|
/* nrows */ 1,
|
||||||
|
/* b_size */ 1,
|
||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
|
@ -170,12 +170,46 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
|||||||
let res2 = matmul.forward(&lhs2)?;
|
let res2 = matmul.forward(&lhs2)?;
|
||||||
let res2 = res2.i(1)?;
|
let res2 = res2.i(1)?;
|
||||||
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
assert_eq!(diff, 0.);
|
if device.is_cuda() {
|
||||||
|
assert!(diff < 0.1);
|
||||||
|
} else {
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn qmm_batch(dev: &Device) -> Result<()> {
|
||||||
|
let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
|
||||||
|
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
|
||||||
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
assert_eq!(mm.shape().dims(), [2, 6]);
|
||||||
|
let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
|
||||||
|
let mm2 = rhs.forward(&lhs2)?;
|
||||||
|
assert_eq!(mm2.shape().dims(), [4, 6]);
|
||||||
|
let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff2, 0.0);
|
||||||
|
let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
|
||||||
|
let mm3 = rhs.forward(&lhs3)?;
|
||||||
|
assert_eq!(mm3.shape().dims(), [6, 6]);
|
||||||
|
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
if dev.is_cuda() {
|
||||||
|
assert!(diff3 < 1e-4)
|
||||||
|
} else {
|
||||||
|
assert_eq!(diff3, 0.0)
|
||||||
|
};
|
||||||
|
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
if dev.is_cuda() {
|
||||||
|
assert!(diff3 < 1e-4)
|
||||||
|
} else {
|
||||||
|
assert_eq!(diff3, 0.0)
|
||||||
|
};
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
|
test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
|
||||||
test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
|
test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
|
||||||
|
test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);
|
||||||
|
|
||||||
fn quantize_q4_0(device: &Device) -> Result<()> {
|
fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
|
@ -2648,7 +2648,8 @@ static __device__ void mul_mat_vec_q(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda(
|
// batch size = 1
|
||||||
|
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda1(
|
||||||
const void * vx, const void * vy, float * dst,
|
const void * vx, const void * vy, float * dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
@ -2656,7 +2657,7 @@ extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda(
|
|||||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda(
|
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda1(
|
||||||
const void * vx, const void * vy, float * dst,
|
const void * vx, const void * vy, float * dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
@ -2664,7 +2665,7 @@ extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda(
|
|||||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda(
|
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda1(
|
||||||
const void * vx, const void * vy, float * dst,
|
const void * vx, const void * vy, float * dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
@ -2672,7 +2673,7 @@ extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda(
|
|||||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda(
|
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda1(
|
||||||
const void * vx, const void * vy, float * dst,
|
const void * vx, const void * vy, float * dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
@ -2680,7 +2681,7 @@ extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda(
|
|||||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda(
|
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda1(
|
||||||
const void * vx, const void * vy, float * dst,
|
const void * vx, const void * vy, float * dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
@ -2688,7 +2689,7 @@ extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda(
|
|||||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda(
|
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda1(
|
||||||
const void * vx, const void * vy, float * dst,
|
const void * vx, const void * vy, float * dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
@ -2696,7 +2697,7 @@ extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda(
|
|||||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda(
|
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda1(
|
||||||
const void * vx, const void * vy, float * dst,
|
const void * vx, const void * vy, float * dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
@ -2704,7 +2705,7 @@ extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda(
|
|||||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda(
|
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda1(
|
||||||
const void * vx, const void * vy, float * dst,
|
const void * vx, const void * vy, float * dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
@ -2712,7 +2713,7 @@ extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda(
|
|||||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda(
|
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda1(
|
||||||
const void * vx, const void * vy, float * dst,
|
const void * vx, const void * vy, float * dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
@ -2720,7 +2721,7 @@ extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda(
|
|||||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda(
|
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda1(
|
||||||
const void * vx, const void * vy, float * dst,
|
const void * vx, const void * vy, float * dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
@ -2728,6 +2729,249 @@ extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda(
|
|||||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// batch size = 2
|
||||||
|
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda2(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<2, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda2(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<2, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda2(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<2, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda2(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<2, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda2(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<2, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda2(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<2, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda2(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<2, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda2(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<2, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda2(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<2, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda2(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<2, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
// batch size = 3
|
||||||
|
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda3(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<3, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda3(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<3, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda3(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<3, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda3(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<3, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda3(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<3, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda3(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<3, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda3(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<3, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda3(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<3, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda3(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<3, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda3(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<3, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
// batch size = 4
|
||||||
|
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda4(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<4, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda4(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<4, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda4(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<4, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda4(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<4, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda4(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<4, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda4(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<4, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda4(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<4, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda4(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<4, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda4(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<4, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda4(
|
||||||
|
const void * vx, const void * vy, float * dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
|
mul_mat_vec_q<4, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
|
}
|
||||||
|
|
||||||
extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
|
extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
|
||||||
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
|
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user