mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Simd128 version of the q2k-q8k vecdot product. (#1011)
* Sketch the simd128 version of q2k vecdot. * Use a single accumulator. * Simdify the q2k-q8k vecdot product. * Cosmetic change.
This commit is contained in:
@ -710,18 +710,17 @@ impl GgmlType for BlockQ2K {
|
||||
|
||||
let mut isum = 0;
|
||||
let mut is = 0;
|
||||
let mut d;
|
||||
for _ in 0..(QK_K / 128) {
|
||||
let mut shift = 0;
|
||||
for _ in 0..4 {
|
||||
d = (sc[is] & 0xF) as i32;
|
||||
let d = (sc[is] & 0xF) as i32;
|
||||
is += 1;
|
||||
let mut isuml = 0;
|
||||
for l in 0..16 {
|
||||
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
|
||||
}
|
||||
isum += d * isuml;
|
||||
d = (sc[is] & 0xF) as i32;
|
||||
let d = (sc[is] & 0xF) as i32;
|
||||
is += 1;
|
||||
isuml = 0;
|
||||
for l in 16..32 {
|
||||
@ -1086,7 +1085,6 @@ impl GgmlType for BlockQ3K {
|
||||
let d_all = block.d.to_f32();
|
||||
let mut m = 1;
|
||||
let mut is = 0;
|
||||
let mut dl;
|
||||
|
||||
// Dequantize both 128 long blocks
|
||||
// 32 qs values per 128 long block
|
||||
@ -1097,7 +1095,7 @@ impl GgmlType for BlockQ3K {
|
||||
for (scale_index, scale_scoped_y) in
|
||||
shift_scoped_y.chunks_exact_mut(16).enumerate()
|
||||
{
|
||||
dl = d_all * (scales[is] as f32 - 32.0);
|
||||
let dl = d_all * (scales[is] as f32 - 32.0);
|
||||
for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
|
||||
let new_y = dl
|
||||
* (((qs[i + 16 * scale_index] >> shift) & 3) as i8
|
||||
|
@ -102,53 +102,85 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
|
||||
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
let mut sumf = 0.0;
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let mut q2: &[_] = &x.qs;
|
||||
let mut q8: &[_] = &y.qs;
|
||||
let sc = &x.scales;
|
||||
unsafe {
|
||||
let mut sumf = f32x4_splat(0f32);
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let mut q2: &[_] = &x.qs;
|
||||
let mut q8: &[_] = &y.qs;
|
||||
let sc = &x.scales;
|
||||
|
||||
let mut summs = 0;
|
||||
for (bsum, scale) in y.bsums.iter().zip(sc) {
|
||||
summs += *bsum as i32 * ((scale >> 4) as i32);
|
||||
}
|
||||
|
||||
let dall = y.d * x.d.to_f32();
|
||||
let dmin = y.d * x.dmin.to_f32();
|
||||
|
||||
let mut isum = 0;
|
||||
let mut is = 0;
|
||||
let mut d;
|
||||
for _ in 0..(QK_K / 128) {
|
||||
let mut shift = 0;
|
||||
for _ in 0..4 {
|
||||
d = (sc[is] & 0xF) as i32;
|
||||
is += 1;
|
||||
let mut isuml = 0;
|
||||
for l in 0..16 {
|
||||
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
|
||||
}
|
||||
isum += d * isuml;
|
||||
d = (sc[is] & 0xF) as i32;
|
||||
is += 1;
|
||||
isuml = 0;
|
||||
for l in 16..32 {
|
||||
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
|
||||
}
|
||||
isum += d * isuml;
|
||||
shift += 2;
|
||||
// adjust the indexing
|
||||
q8 = &q8[32..];
|
||||
let mut summs = i32x4_splat(0);
|
||||
for i in (0..(QK_K / 16)).step_by(4) {
|
||||
let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
|
||||
let scales = i32x4_shr(
|
||||
i32x4(
|
||||
sc[i] as i32,
|
||||
sc[i + 1] as i32,
|
||||
sc[i + 2] as i32,
|
||||
sc[i + 3] as i32,
|
||||
),
|
||||
4,
|
||||
);
|
||||
summs = i32x4_add(summs, i32x4_mul(bsums, scales))
|
||||
}
|
||||
// adjust the indexing
|
||||
q2 = &q2[32..];
|
||||
}
|
||||
sumf += dall * isum as f32 - dmin * summs as f32;
|
||||
}
|
||||
let summs = f32x4_convert_i32x4(summs);
|
||||
|
||||
Ok(sumf)
|
||||
let dall = y.d * x.d.to_f32();
|
||||
let dmin = y.d * x.dmin.to_f32();
|
||||
|
||||
let mut isum = i32x4_splat(0);
|
||||
let mut is = 0;
|
||||
for _ in 0..(QK_K / 128) {
|
||||
let mut shift = 0;
|
||||
for _ in 0..4 {
|
||||
let d = (sc[is] & 0xF) as i32;
|
||||
is += 1;
|
||||
let mut isuml = i16x8_splat(0);
|
||||
for l in (0..16).step_by(8) {
|
||||
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
|
||||
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
|
||||
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
|
||||
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
|
||||
}
|
||||
let dd = i32x4_splat(d);
|
||||
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
|
||||
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
|
||||
let d = (sc[is] & 0xF) as i32;
|
||||
is += 1;
|
||||
let mut isuml = i16x8_splat(0);
|
||||
for l in (16..32).step_by(8) {
|
||||
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
|
||||
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
|
||||
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
|
||||
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
|
||||
}
|
||||
let dd = i32x4_splat(d);
|
||||
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
|
||||
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
|
||||
shift += 2;
|
||||
// adjust the indexing
|
||||
q8 = &q8[32..];
|
||||
}
|
||||
// adjust the indexing
|
||||
q2 = &q2[32..];
|
||||
}
|
||||
let isum = f32x4_convert_i32x4(isum);
|
||||
sumf = f32x4_add(
|
||||
sumf,
|
||||
f32x4_sub(
|
||||
f32x4_mul(isum, f32x4_splat(dall)),
|
||||
f32x4_mul(summs, f32x4_splat(dmin)),
|
||||
),
|
||||
);
|
||||
}
|
||||
let sumf = f32x4_extract_lane::<0>(sumf)
|
||||
+ f32x4_extract_lane::<1>(sumf)
|
||||
+ f32x4_extract_lane::<2>(sumf)
|
||||
+ f32x4_extract_lane::<3>(sumf);
|
||||
Ok(sumf)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
|
@ -514,7 +514,7 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
||||
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
|
||||
|
||||
if error > GGML_MAX_DOT_PRODUCT_ERROR {
|
||||
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
|
||||
candle_core::bail!(
|
||||
"Dot product error {} exceeds max error {}",
|
||||
error,
|
||||
|
Reference in New Issue
Block a user