mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Simd128 version of q6k vec-dot. (#1015)
* Add a specific function for the simd128 q6k vec-dot. * Simdification. * More simdification.
This commit is contained in:
@ -1539,6 +1539,9 @@ impl GgmlType for BlockQ6K {
|
|||||||
#[cfg(target_feature = "neon")]
|
#[cfg(target_feature = "neon")]
|
||||||
return super::neon::vec_dot_q6k_q8k(n, xs, ys);
|
return super::neon::vec_dot_q6k_q8k(n, xs, ys);
|
||||||
|
|
||||||
|
#[cfg(target_feature = "simd128")]
|
||||||
|
return super::simd128::vec_dot_q6k_q8k(n, xs, ys);
|
||||||
|
|
||||||
if n % QK_K != 0 {
|
if n % QK_K != 0 {
|
||||||
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
|
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
|
use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
use byteorder::{ByteOrder, LittleEndian};
|
use byteorder::{ByteOrder, LittleEndian};
|
||||||
use half::f16;
|
use half::f16;
|
||||||
@ -272,3 +272,126 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
|||||||
Ok(sums)
|
Ok(sums)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||||
|
if n % QK_K != 0 {
|
||||||
|
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut aux8 = [0i8; QK_K];
|
||||||
|
unsafe {
|
||||||
|
let mut sums = f32x4_splat(0f32);
|
||||||
|
|
||||||
|
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||||
|
let q4 = &x.ql;
|
||||||
|
let qh = &x.qh;
|
||||||
|
let q8 = &y.qs;
|
||||||
|
let mut aux32 = f32x4_splat(0f32);
|
||||||
|
|
||||||
|
for j in (0..QK_K).step_by(128) {
|
||||||
|
let aux8 = aux8.as_mut_ptr().add(j);
|
||||||
|
let q4 = &q4.as_ptr().add(j / 2);
|
||||||
|
let qh = &qh.as_ptr().add(j / 4);
|
||||||
|
for l in (0..32).step_by(16) {
|
||||||
|
// aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
|
||||||
|
let a8 = v128_or(
|
||||||
|
v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)),
|
||||||
|
u8x16_shl(
|
||||||
|
v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)),
|
||||||
|
4,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
|
||||||
|
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
|
||||||
|
v128_store(
|
||||||
|
aux8.add(l) as *mut v128,
|
||||||
|
i8x16_narrow_i16x8(a8_low, a8_high),
|
||||||
|
);
|
||||||
|
|
||||||
|
// aux8[l + 32] =
|
||||||
|
// (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
|
||||||
|
let a8 = v128_or(
|
||||||
|
v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)),
|
||||||
|
u8x16_shl(
|
||||||
|
v128_and(
|
||||||
|
u8x16_shr(v128_load(qh.add(l) as *const v128), 2),
|
||||||
|
u8x16_splat(3),
|
||||||
|
),
|
||||||
|
4,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
|
||||||
|
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
|
||||||
|
v128_store(
|
||||||
|
aux8.add(l + 32) as *mut v128,
|
||||||
|
i8x16_narrow_i16x8(a8_low, a8_high),
|
||||||
|
);
|
||||||
|
|
||||||
|
// aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
|
||||||
|
let a8 = v128_or(
|
||||||
|
u8x16_shr(v128_load(q4.add(l) as *const v128), 4),
|
||||||
|
u8x16_shl(
|
||||||
|
v128_and(
|
||||||
|
u8x16_shr(v128_load(qh.add(l) as *const v128), 4),
|
||||||
|
u8x16_splat(3),
|
||||||
|
),
|
||||||
|
4,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
|
||||||
|
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
|
||||||
|
v128_store(
|
||||||
|
aux8.add(l + 64) as *mut v128,
|
||||||
|
i8x16_narrow_i16x8(a8_low, a8_high),
|
||||||
|
);
|
||||||
|
|
||||||
|
// aux8[l + 96] =
|
||||||
|
// (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
|
||||||
|
let a8 = v128_or(
|
||||||
|
u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4),
|
||||||
|
u8x16_shl(
|
||||||
|
v128_and(
|
||||||
|
u8x16_shr(v128_load(qh.add(l) as *const v128), 6),
|
||||||
|
u8x16_splat(3),
|
||||||
|
),
|
||||||
|
4,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
|
||||||
|
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
|
||||||
|
v128_store(
|
||||||
|
aux8.add(l + 96) as *mut v128,
|
||||||
|
i8x16_narrow_i16x8(a8_low, a8_high),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (j, &scale) in x.scales.iter().enumerate() {
|
||||||
|
let scale = f32x4_splat(scale as f32);
|
||||||
|
for offset in [0, 8] {
|
||||||
|
let aux16 = i16x8_mul(
|
||||||
|
i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)),
|
||||||
|
i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)),
|
||||||
|
);
|
||||||
|
aux32 = f32x4_add(
|
||||||
|
aux32,
|
||||||
|
f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale),
|
||||||
|
);
|
||||||
|
aux32 = f32x4_add(
|
||||||
|
aux32,
|
||||||
|
f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let d = f32x4_splat(x.d.to_f32() * y.d);
|
||||||
|
sums = f32x4_add(sums, f32x4_mul(aux32, d));
|
||||||
|
}
|
||||||
|
let sums = f32x4_extract_lane::<0>(sums)
|
||||||
|
+ f32x4_extract_lane::<1>(sums)
|
||||||
|
+ f32x4_extract_lane::<2>(sums)
|
||||||
|
+ f32x4_extract_lane::<3>(sums);
|
||||||
|
Ok(sums)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user