Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)

* Simpler repro for the neon optimization issue.
* Bugfix for q4k.
* Improve the fix, share the dot-prod bit.
* Clippy fixes.
* Fix for q6k.
* Also fix for q2k.
* Use the new shared dotprod.
* Add more testing.
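
//! NEON (Arm / AArch64) implementations of the `vec_dot` kernels for the quantized block
//! formats defined in `k_quants`.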
use super::k_quants::{
    BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};

#[allow(unused_imports)]
#[cfg(target_arch = "arm")]
use core::arch::arm::*;

#[allow(unused_imports)]
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
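
/// Portable stand-in for the Arm `dotprod` instruction: widen the signed 8-bit lanes with
/// `vmull_s8` and pairwise-add the 16-bit products down to four i32 partial sums. The lane
/// grouping differs from the hardware `vdotq_s32`, but the horizontal total is identical,
/// which is all the callers below rely on.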
#[inline(always)]
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
    // TODO: dotprod
    let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
    let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
    vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
}
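
/// Dot product of a row of Q4_0 blocks against a row of Q8_0 blocks.
/// Each Q4_0 block stores 32 weights as packed 4-bit values plus a half-precision scale `d`;
/// the nibbles are unpacked, re-centered by subtracting 8, multiplied against the 8-bit `ys`
/// values, and each block's integer sum is accumulated weighted by `x.d * y.d`.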
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    let nb = n / qk;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }

    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        for i in 0..nb {
            let x0 = &xs[i];
            let y0 = &ys[i];

            let m4b = vdupq_n_u8(0x0F);
            let s8b = vdupq_n_s8(0x8);

            let v0_0 = vld1q_u8(x0.qs.as_ptr());

            // 4-bit -> 8-bit
            let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
            let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));

            // sub 8
            let v0_0ls = vsubq_s8(v0_0l, s8b);
            let v0_0hs = vsubq_s8(v0_0h, s8b);

            // load y
            let v1_0l = vld1q_s8(y0.qs.as_ptr());
            let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));

            let pl0 = vdotq_s32(v0_0ls, v1_0l);
            let ph0 = vdotq_s32(v0_0hs, v1_0h);
            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0))
    }
}
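
/// Dot product of two rows of Q8_0 blocks: 32 signed 8-bit values per block, accumulated as
/// integers and weighted by the product of the two per-block scales.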
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    let nb = n / QK8_0;
    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        for i in 0..nb {
            let x0 = &xs[i];
            let y0 = &ys[i];

            let x0_0 = vld1q_s8(x0.qs.as_ptr());
            let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));

            // load y
            let y0_0 = vld1q_s8(y0.qs.as_ptr());
            let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));

            let p0 = vdotq_s32(x0_0, y0_0);
            let p1 = vdotq_s32(x0_1, y0_1);

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(p0, p1)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0))
    }
}
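
/// Dot product of two rows of Q8K super-blocks (QK_K values each), processed sixteen bytes at
/// a time and weighted by `xs.d * ys.d` per block.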
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
    let qk = QK_K;
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
    }

    let mut sumf = 0f32;
    for (xs, ys) in xs.iter().zip(ys.iter()) {
        unsafe {
            let mut sum_i = vdupq_n_s32(0);
            let scale = xs.d * ys.d;
            let xs = xs.qs.as_ptr();
            let ys = ys.qs.as_ptr();
            for i in (0..QK_K).step_by(16) {
                let xs = vld1q_s8(xs.add(i));
                let ys = vld1q_s8(ys.add(i));
                let xy = vdotq_s32(xs, ys);
                sum_i = vaddq_s32(sum_i, xy)
            }
            sumf += vaddvq_s32(sum_i) as f32 * scale
        }
    }
    Ok(sumf)
}
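
/// Dot product of Q6K blocks against Q8K blocks. The 6-bit weights are rebuilt from four low
/// bits in `ql` and two high bits in `qh`; the implicit +32 offset of the stored values is
/// removed at the end through the precomputed `bsums`/scales product (`isum_mins`).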
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sum = 0f32;
    unsafe {
        let m4b = vdupq_n_u8(0xF);

        let mone = vdupq_n_u8(3);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d_all = x.d.to_f32();

            let mut q6 = x.ql.as_ptr();
            let mut qh = x.qh.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut scale = x.scales.as_ptr();

            let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
            let scales = vld1q_s8(scale);
            let q6scales = int16x8x2_t(
                vmovl_s8(vget_low_s8(scales)),
                vmovl_s8(vget_high_s8(scales)),
            );

            let prod = vaddq_s32(
                vaddq_s32(
                    vmull_s16(vget_low_s16(q8sums.0), vget_low_s16(q6scales.0)),
                    vmull_s16(vget_high_s16(q8sums.0), vget_high_s16(q6scales.0)),
                ),
                vaddq_s32(
                    vmull_s16(vget_low_s16(q8sums.1), vget_low_s16(q6scales.1)),
                    vmull_s16(vget_high_s16(q8sums.1), vget_high_s16(q6scales.1)),
                ),
            );
            let isum_mins = vaddvq_s32(prod);

            let mut isum = 0i32;

            for _j in 0..QK_K / 128 {
                let qhbits = vld1q_u8_x2(qh);
                qh = qh.add(32);
                let q6bits = vld1q_u8_x4(q6);
                q6 = q6.add(64);
                let q8bytes = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let q6h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
                let q6h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
                let shifted = vshrq_n_u8(qhbits.0, 2);
                let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.1, 2);
                let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);

                let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.0, m4b), q6h_0));
                let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.1, m4b), q6h_1));
                let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
                let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));

                let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
                let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
                scale = scale.add(2);

                let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
                let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
                scale = scale.add(2);

                let q8bytes = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let shifted = vshrq_n_u8(qhbits.0, 4);
                let q6h_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.1, 4);
                let q6h_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.0, 6);
                let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.1, 6);
                let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);

                let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.0, 4), q6h_0));
                let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.1, 4), q6h_1));
                let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
                let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));

                let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
                let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
                scale = scale.add(2);

                let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
                let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
                scale = scale.add(2);
            }
            sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
        }
    }
    Ok(sum)
}
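
/// Dot product of Q5K blocks against Q8K blocks. Each weight combines four bits from `qs`
/// with one high bit from `qh`; the packed 6-bit scales and mins are unpacked into `utmp`,
/// and the mins contribution is subtracted using the Q8K `bsums`.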
#[inline(always)]
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut utmp = [0u32; 4];
    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    unsafe {
        let m4b = vdupq_n_u8(0xF);
        let mone = vdupq_n_u8(1);
        let mtwo = vdupq_n_u8(2);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let q8sums = vpaddq_s16(
                vld1q_s16(y.bsums.as_ptr()),
                vld1q_s16(y.bsums.as_ptr().add(8)),
            );

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8));
            let mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
            let prod = vaddq_s32(
                vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
                vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
            );
            let sumi_mins = vaddvq_s32(prod);

            let mut scales = utmp.as_ptr() as *const u8;

            let mut q5 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut qhbits = vld1q_u8_x2(x.qh.as_ptr());

            let mut sumi = 0i32;

            for _j in 0..QK_K / 64 {
                let q5bits = vld1q_u8_x2(q5);
                q5 = q5.add(32);
                let q8bytes = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
                let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
                let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3);
                let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3);
                qhbits.0 = vshrq_n_u8(qhbits.0, 2);
                qhbits.1 = vshrq_n_u8(qhbits.1, 2);

                let q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0));
                let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1));
                let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
                let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));

                let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
                let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
                sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
                scales = scales.add(1);

                let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
                let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
                sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
                scales = scales.add(1);
            }
            sumf += d * sumi as f32 - dmin * sumi_mins as f32;
        }
    }
    Ok(sumf)
}
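
/// Dot product of Q4K blocks against Q8K blocks. Scales and mins are unpacked from the packed
/// 6-bit encoding in `x.scales`; the mins are gathered into a fresh array instead of being
/// read back from `utmp`, which appears to be the q4k fix mentioned in the change notes above.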
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut utmp = [0u32; 4];
    let mut scales = [0u8; 16];
    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    unsafe {
        let m4b = vdupq_n_u8(0xF);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let q8sums = vpaddq_s16(
                vld1q_s16(y.bsums.as_ptr()),
                vld1q_s16(y.bsums.as_ptr().add(8)),
            );

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            let mins8 = vld1_u32(
                [
                    utmp[1] & KMASK1,
                    ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4),
                ]
                .as_ptr(),
            );
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[0] &= KMASK1;

            let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
            let prod = vaddq_s32(
                vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
                vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
            );
            sumf -= dmin * vaddvq_s32(prod) as f32;

            LittleEndian::write_u32_into(&utmp, &mut scales);

            let mut q4 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut sumi1 = 0i32;
            let mut sumi2 = 0i32;

            for j in 0..QK_K / 64 {
                let q4bits = vld1q_u8_x2(q4);
                q4 = q4.add(32);
                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let q4bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
                    vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
                );
                let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
                let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
                sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let q4bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
                );
                let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
                let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
                sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
            }
            sumf += d * (sumi1 + sumi2) as f32;
        }
    }
    Ok(sumf)
}
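
/// Dot product of Q3K blocks against Q8K blocks. The low two bits of each weight come from
/// `qs` and the third bit from the `hmask` bitfield; the unpacked scales are re-centered by
/// subtracting 32 before use.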
#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut utmp = [0u32; 4];
    let mut aux = [0u32; 3];
    const KMASK1: u32 = 0x03030303;
    const KMASK2: u32 = 0x0f0f0f0f;

    unsafe {
        let m3b = vdupq_n_u8(0x3);
        let m0 = vdupq_n_u8(1);
        let m1 = vshlq_n_u8(m0, 1);
        let m2 = vshlq_n_u8(m0, 2);
        let m3 = vshlq_n_u8(m0, 3);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let mut q3 = x.qs.as_ptr();
            let qh = x.hmask.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut qhbits = vld1q_u8_x2(qh);

            let mut isum = 0i32;

            // Set up scales
            LittleEndian::read_u32_into(&x.scales, &mut aux);

            utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4);
            utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4);
            utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4);
            utmp[0] = (aux[0] & KMASK2) | ((aux[2] & KMASK1) << 4);

            let mut scale = utmp.as_mut_ptr() as *mut i8;
            for j in 0..16 {
                *scale.add(j) -= 32i8
            }

            for j in 0..QK_K / 128 {
                let q3bits = vld1q_u8_x2(q3);
                q3 = q3.add(32);
                let q8bytes_1 = vld1q_s8_x4(q8);
                q8 = q8.add(64);
                let q8bytes_2 = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2);
                let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2);
                let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1);
                let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1);

                let q3bytes_0 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)),
                    vreinterpretq_s8_u8(q3h_0),
                );
                let q3bytes_1 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)),
                    vreinterpretq_s8_u8(q3h_1),
                );
                let q3bytes_2 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)),
                    vreinterpretq_s8_u8(q3h_2),
                );
                let q3bytes_3 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)),
                    vreinterpretq_s8_u8(q3h_3),
                );

                let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
                let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
                let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
                let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
                isum += vaddvq_s32(p0) * *scale as i32
                    + vaddvq_s32(p1) * *scale.add(1) as i32
                    + vaddvq_s32(p2) * *scale.add(2) as i32
                    + vaddvq_s32(p3) * *scale.add(3) as i32;
                scale = scale.add(4);

                let q3h_0 = vbicq_u8(m2, qhbits.0);
                let q3h_1 = vbicq_u8(m2, qhbits.1);
                let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1);
                let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1);

                let q3bytes_0 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)),
                    vreinterpretq_s8_u8(q3h_0),
                );
                let q3bytes_1 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)),
                    vreinterpretq_s8_u8(q3h_1),
                );
                let q3bytes_2 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)),
                    vreinterpretq_s8_u8(q3h_2),
                );
                let q3bytes_3 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)),
                    vreinterpretq_s8_u8(q3h_3),
                );

                let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
                let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
                let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
                let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
                isum += vaddvq_s32(p0) * *scale as i32
                    + vaddvq_s32(p1) * *scale.add(1) as i32
                    + vaddvq_s32(p2) * *scale.add(2) as i32
                    + vaddvq_s32(p3) * *scale.add(3) as i32;
                scale = scale.add(4);

                if j == 0 {
                    qhbits.0 = vshrq_n_u8(qhbits.0, 4);
                    qhbits.1 = vshrq_n_u8(qhbits.1, 4);
                }
            }
            sumf += d * isum as f32;
        }
    }
    Ok(sumf)
}
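
/// Dot product of Q2K blocks against Q8K blocks. Each byte of `x.scales` packs a 4-bit scale
/// and a 4-bit min; the mins are folded in through the Q8K `bsums`, and the 2-bit weights are
/// accumulated via `multiply_accum_with_scale` below.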
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut aux = [0u8; 16];

    unsafe {
        let m3 = vdupq_n_u8(0x3);
        let m4 = vdupq_n_u8(0xF);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = -y.d * x.dmin.to_f32();

            let mut q2 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();
            let sc = x.scales.as_ptr();

            let mins_and_scales = vld1q_u8(sc);
            let scales = vandq_u8(mins_and_scales, m4);
            vst1q_u8(aux.as_mut_ptr(), scales);

            let mins = vshrq_n_u8(mins_and_scales, 4);
            let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
            let mins16 = int16x8x2_t(
                vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))),
                vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))),
            );
            let s0 = vaddq_s32(
                vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)),
                vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)),
            );
            let s1 = vaddq_s32(
                vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)),
                vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)),
            );
            sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32;

            let mut isum = 0i32;
            let mut is = 0usize;

            // TODO: dotprod
            for _j in 0..QK_K / 128 {
                let q2bits = vld1q_u8_x2(q2);
                q2 = q2.add(32);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let mut q2bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)),
                    vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)),
                );
                isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3));
                q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3));
                isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3));
                q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3));
                isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3));
                q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3));
                isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes);

                is += 8;
            }
            sumf += d * isum as f32;
        }
    }
    Ok(sumf)
}
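
/// Helper shared by the Q2K kernel: dot two pairs of i8x16 vectors with `vdotq_s32`, reduce
/// each result to a scalar, and weight the two sums by consecutive scale bytes from `aux`.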
#[inline(always)]
unsafe fn multiply_accum_with_scale(
    aux: &[u8; 16],
    is: usize,
    index: usize,
    q2bytes: int8x16x2_t,
    q8bytes: int8x16x2_t,
) -> i32 {
    let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
    let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
    vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
}