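//! NEON (arm/aarch64) implementations of the quantized block dot products.
//!
//! Each `vec_dot_*` kernel computes the inner product of a row of quantized
//! blocks against a row of Q8 blocks, returning the dequantized f32 result.
//! The i8 dot-product intrinsic (`vdotq_s32`) is not used here yet (see the
//! TODOs below); instead the kernels emulate it with widening multiplies
//! (`vmull_s8`) followed by pairwise additions (`vpaddlq_s16`).
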
use super::k_quants::{
    BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};

#[allow(unused_imports)]
#[cfg(target_arch = "arm")]
use core::arch::arm::*;

#[allow(unused_imports)]
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;

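/// Dot product of a Q4_0 row against a Q8_0 row, one 32-element block per
/// iteration: unpack the 4-bit nibbles, recenter by subtracting 8, then
/// widen-multiply against the i8 values and scale by `x.d * y.d`.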
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    let nb = n / qk;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }

    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        for i in 0..nb {
            let x0 = &xs[i];
            let y0 = &ys[i];

            let m4b = vdupq_n_u8(0x0F);
            let s8b = vdupq_n_s8(0x8);

            let v0_0 = vld1q_u8(x0.qs.as_ptr());

            // 4-bit -> 8-bit
            let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
            let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));

            // sub 8
            let v0_0ls = vsubq_s8(v0_0l, s8b);
            let v0_0hs = vsubq_s8(v0_0h, s8b);

            // load y
            let v1_0l = vld1q_s8(y0.qs.as_ptr());
            let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));

            // TODO: Support dotprod when it's available outside of nightly.
            let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
            let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
            let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
            let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));

            let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
            let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0))
    }
}

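/// Dot product of two Q8_0 rows: a straight i8 x i8 widening multiply per
/// 32-element block, scaled by the product of the two block scales.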
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    let nb = n / QK8_0;
    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        for i in 0..nb {
            let x0 = &xs[i];
            let y0 = &ys[i];

            let x0_0 = vld1q_s8(x0.qs.as_ptr());
            let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));

            // load y
            let y0_0 = vld1q_s8(y0.qs.as_ptr());
            let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));

            // TODO: Use dotprod once the intrinsics are available outside of nightly.
            let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
            let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
            let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
            let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));

            let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
            let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(p0, p1)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0))
    }
}

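/// Dot product of two Q8K super-blocks (`QK_K` elements each), accumulating
/// 16 lanes at a time and scaling by `x.d * y.d`.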
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
    let qk = QK_K;
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
    }

    let mut sumf = 0f32;
    for (xs, ys) in xs.iter().zip(ys.iter()) {
        unsafe {
            let mut sum_i = vdupq_n_s32(0);
            let scale = xs.d * ys.d;
            let xs = xs.qs.as_ptr();
            let ys = ys.qs.as_ptr();
            for i in (0..QK_K).step_by(16) {
                let xs = vld1q_s8(xs.add(i));
                let ys = vld1q_s8(ys.add(i));
                let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
                let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));

                let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
                sum_i = vaddq_s32(sum_i, xy)
            }
            sumf += vaddvq_s32(sum_i) as f32 * scale
        }
    }
    Ok(sumf)
}

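/// Dot product of a Q6K row against a Q8K row. The 6-bit values are rebuilt
/// from the low nibbles in `ql` plus 2-bit highs from `qh`; the constant -32
/// recentering is applied via the precomputed `bsums` (the `32 * isum_mins`
/// correction at the end) rather than per element.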
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sum = 0f32;
    unsafe {
        let m4b = vdupq_n_u8(0xF);

        let mone = vdupq_n_u8(3);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d_all = x.d.to_f32();

            let mut q6 = x.ql.as_ptr();
            let mut qh = x.qh.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut scale = x.scales.as_ptr();

            let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
            let scales = vld1q_s8(scale);
            let q6scales = int16x8x2_t(
                vmovl_s8(vget_low_s8(scales)),
                vmovl_s8(vget_high_s8(scales)),
            );

            let prod = vaddq_s32(
                vaddq_s32(
                    vmull_s16(vget_low_s16(q8sums.0), vget_low_s16(q6scales.0)),
                    vmull_s16(vget_high_s16(q8sums.0), vget_high_s16(q6scales.0)),
                ),
                vaddq_s32(
                    vmull_s16(vget_low_s16(q8sums.1), vget_low_s16(q6scales.1)),
                    vmull_s16(vget_high_s16(q8sums.1), vget_high_s16(q6scales.1)),
                ),
            );
            let isum_mins = vaddvq_s32(prod);

            let mut isum = 0i32;

            for _j in 0..QK_K / 128 {
                let qhbits = vld1q_u8_x2(qh);
                qh = qh.add(32);
                let q6bits = vld1q_u8_x4(q6);
                q6 = q6.add(64);
                let q8bytes = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let q6h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
                let q6h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
                let shifted = vshrq_n_u8(qhbits.0, 2);
                let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.1, 2);
                let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);

                let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.0, m4b), q6h_0));
                let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.1, m4b), q6h_1));
                let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
                let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));

                // TODO: dotprod

                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
                    vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
                    vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
                );
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
                scale = scale.add(2);

                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
                    vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
                );
                let p3 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
                    vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
                );
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
                scale = scale.add(2);

                let q8bytes = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let shifted = vshrq_n_u8(qhbits.0, 4);
                let q6h_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.1, 4);
                let q6h_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.0, 6);
                let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
                let shifted = vshrq_n_u8(qhbits.1, 6);
                let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);

                let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.0, 4), q6h_0));
                let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.1, 4), q6h_1));
                let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
                let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));

                // TODO: dotprod case.
                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
                    vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
                    vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
                );
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
                scale = scale.add(2);

                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
                    vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
                );
                let p3 = vaddq_s16(
                    vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
                    vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
                );
                let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
                isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
                scale = scale.add(2);
            }
            sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
        }
    }
    Ok(sum)
}

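/// Dot product of a Q5K row against a Q8K row. The packed 6-bit scales and
/// mins are unshuffled through `utmp` with the `KMASK*` constants, the fifth
/// bit of each value comes from `qh`, and the mins contribution is
/// subtracted via the precomputed `bsums`.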
#[inline(always)]
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut utmp = [0u32; 4];
    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    unsafe {
        let m4b = vdupq_n_u8(0xF);
        let mone = vdupq_n_u8(1);
        let mtwo = vdupq_n_u8(2);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let q8sums = vpaddq_s16(
                vld1q_s16(y.bsums.as_ptr()),
                vld1q_s16(y.bsums.as_ptr().add(8)),
            );

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8));
            let mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
            let prod = vaddq_s32(
                vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
                vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
            );
            let sumi_mins = vaddvq_s32(prod);

            let mut scales = utmp.as_ptr() as *const u8;

            let mut q5 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut qhbits = vld1q_u8_x2(x.qh.as_ptr());

            let mut sumi = 0i32;

            for _j in 0..QK_K / 64 {
                let q5bits = vld1q_u8_x2(q5);
                q5 = q5.add(32);
                let q8bytes = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
                let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
                let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3);
                let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3);
                qhbits.0 = vshrq_n_u8(qhbits.0, 2);
                qhbits.1 = vshrq_n_u8(qhbits.1, 2);

                let q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0));
                let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1));
                let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
                let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));

                // TODO: dotprod

                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
                    vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
                    vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
                );
                sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
                scales = scales.add(1);

                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
                    vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
                );
                let p3 = vaddq_s16(
                    vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
                    vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
                );
                sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
                scales = scales.add(1);
            }
            sumf += d * sumi as f32 - dmin * sumi_mins as f32;
        }
    }
    Ok(sumf)
}

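/// Dot product of a Q4K row against a Q8K row: low and high nibbles of each
/// 32-byte chunk are handled separately with their own 6-bit scales, and the
/// mins are folded in through `bsums` before the main loop.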
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut utmp = [0u32; 4];
    let mut scales = [0u8; 16];
    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    unsafe {
        let m4b = vdupq_n_u8(0xF);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let q8sums = vpaddq_s16(
                vld1q_s16(y.bsums.as_ptr()),
                vld1q_s16(y.bsums.as_ptr().add(8)),
            );

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            let mins8 = vld1_u32(
                [
                    utmp[1] & KMASK1,
                    ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4),
                ]
                .as_ptr(),
            );
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[0] &= KMASK1;

            let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
            let prod = vaddq_s32(
                vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
                vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
            );
            sumf -= dmin * vaddvq_s32(prod) as f32;

            LittleEndian::write_u32_into(&utmp, &mut scales);

            let mut q4 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut sumi1 = 0i32;
            let mut sumi2 = 0i32;

            for j in 0..QK_K / 64 {
                let q4bits = vld1q_u8_x2(q4);
                q4 = q4.add(32);
                // TODO: dotprod
                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let q4bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
                    vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
                );
                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
                    vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
                    vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
                );
                sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let q4bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
                );
                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
                    vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
                );
                let p3 = vaddq_s16(
                    vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
                    vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
                );
                sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
            }
            sumf += d * (sumi1 + sumi2) as f32;
        }
    }
    Ok(sumf)
}

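/// Dot product of a Q3K row against a Q8K row. The 2-bit lows come from
/// `qs`, the third bit from `hmask` (applied via `vbicq_u8`), and the packed
/// 6-bit scales are unpacked into `utmp` and recentered by -32.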
#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut utmp = [0u32; 4];
    let mut aux = [0u32; 3];
    const KMASK1: u32 = 0x03030303;
    const KMASK2: u32 = 0x0f0f0f0f;

    unsafe {
        let m3b = vdupq_n_u8(0x3);
        let m0 = vdupq_n_u8(1);
        let m1 = vshlq_n_u8(m0, 1);
        let m2 = vshlq_n_u8(m0, 2);
        let m3 = vshlq_n_u8(m0, 3);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let mut q3 = x.qs.as_ptr();
            let qh = x.hmask.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut qhbits = vld1q_u8_x2(qh);

            let mut isum = 0i32;

            // Set up scales
            LittleEndian::read_u32_into(&x.scales, &mut aux);

            utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4);
            utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4);
            utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4);
            utmp[0] = (aux[0] & KMASK2) | ((aux[2] & KMASK1) << 4);

            let mut scale = utmp.as_mut_ptr() as *mut i8;
            for j in 0..16 {
                *scale.add(j) -= 32i8
            }

            for j in 0..QK_K / 128 {
                let q3bits = vld1q_u8_x2(q3);
                q3 = q3.add(32);
                let q8bytes_1 = vld1q_s8_x4(q8);
                q8 = q8.add(64);
                let q8bytes_2 = vld1q_s8_x4(q8);
                q8 = q8.add(64);

                let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2);
                let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2);
                let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1);
                let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1);

                let q3bytes_0 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)),
                    vreinterpretq_s8_u8(q3h_0),
                );
                let q3bytes_1 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)),
                    vreinterpretq_s8_u8(q3h_1),
                );
                let q3bytes_2 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)),
                    vreinterpretq_s8_u8(q3h_2),
                );
                let q3bytes_3 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)),
                    vreinterpretq_s8_u8(q3h_3),
                );

                // TODO: dotprod
                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
                    vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
                    vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
                );
                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
                    vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
                );
                let p3 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
                    vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
                );
                isum += vaddvq_s16(p0) as i32 * *scale as i32
                    + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
                    + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
                    + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
                scale = scale.add(4);

                let q3h_0 = vbicq_u8(m2, qhbits.0);
                let q3h_1 = vbicq_u8(m2, qhbits.1);
                let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1);
                let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1);

                let q3bytes_0 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)),
                    vreinterpretq_s8_u8(q3h_0),
                );
                let q3bytes_1 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)),
                    vreinterpretq_s8_u8(q3h_1),
                );
                let q3bytes_2 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)),
                    vreinterpretq_s8_u8(q3h_2),
                );
                let q3bytes_3 = vsubq_s8(
                    vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)),
                    vreinterpretq_s8_u8(q3h_3),
                );

                // TODO: dotprod
                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
                    vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
                    vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
                );
                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
                    vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
                );
                let p3 = vaddq_s16(
                    vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
                    vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
                );
                isum += vaddvq_s16(p0) as i32 * *scale as i32
                    + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
                    + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
                    + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
                scale = scale.add(4);

                if j == 0 {
                    qhbits.0 = vshrq_n_u8(qhbits.0, 4);
                    qhbits.1 = vshrq_n_u8(qhbits.1, 4);
                }
            }
            sumf += d * isum as f32;
        }
    }
    Ok(sumf)
}

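/// Dot product of a Q2K row against a Q8K row. Each byte of `x.scales` packs
/// a scale (low nibble) and a min (high nibble); the mins term is
/// accumulated from `bsums` up front and the 2-bit values are processed 32
/// at a time via `multiply_accum_with_scale`.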
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    let mut aux = [0u8; 16];

    unsafe {
        let m3 = vdupq_n_u8(0x3);
        let m4 = vdupq_n_u8(0xF);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = -y.d * x.dmin.to_f32();

            let mut q2 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();
            let sc = x.scales.as_ptr();

            let mins_and_scales = vld1q_u8(sc);
            let scales = vandq_u8(mins_and_scales, m4);
            vst1q_u8(aux.as_mut_ptr(), scales);

            let mins = vshrq_n_u8(mins_and_scales, 4);
            let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
            let mins16 = int16x8x2_t(
                vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))),
                vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))),
            );
            let s0 = vaddq_s32(
                vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)),
                vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)),
            );
            let s1 = vaddq_s32(
                vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)),
                vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)),
            );
            sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32;

            let mut isum = 0i32;
            let mut is = 0usize;

            // TODO: dotprod

            for _j in 0..QK_K / 128 {
                let q2bits = vld1q_u8_x2(q2);
                q2 = q2.add(32);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let mut q2bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)),
                    vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)),
                );
                isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3));
                q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3));
                isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3));
                q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3));
                isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes);

                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3));
                q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3));
                isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes);

                is += 8;
            }
            sumf += d * isum as f32;
        }
    }
    Ok(sumf)
}

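/// Helper for `vec_dot_q2k_q8k`: widening multiply of two 16-byte vector
/// pairs, with each half weighted by its sub-block scale from `aux`.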
#[inline(always)]
unsafe fn multiply_accum_with_scale(
    aux: &[u8; 16],
    is: usize,
    index: usize,
    q2bytes: int8x16x2_t,
    q8bytes: int8x16x2_t,
) -> i32 {
    let p1 = vaddq_s16(
        vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
        vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
    );
    let p2 = vaddq_s16(
        vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
        vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
    );
    vaddvq_s16(p1) as i32 * aux[is + index] as i32
        + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
}
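
#[cfg(test)]
mod tests {
    // Illustrative smoke test, not part of the upstream file: a minimal
    // sketch that checks the q8_0 x q8_0 kernel against a scalar reference
    // on a single block. It assumes `BlockQ8_0 { d: f16, qs: [i8; QK8_0] }`
    // as defined in `super::k_quants`, with `QK8_0 == 32`.
    use super::*;
    use half::f16;

    #[test]
    fn q8_0_dot_matches_scalar_reference() -> Result<()> {
        // Fill one block with small deterministic values in [-16, 15].
        let qs: [i8; QK8_0] = core::array::from_fn(|i| i as i8 - 16);
        let block = BlockQ8_0 {
            d: f16::from_f32(0.5),
            qs,
        };
        // Scalar reference: d_x * d_y * sum(x_i * y_i).
        let expected: f32 = 0.5 * 0.5 * qs.iter().map(|&v| v as f32 * v as f32).sum::<f32>();
        let blocks = [block];
        let got = vec_dot_q8_0_q8_0(QK8_0, &blocks, &blocks)?;
        assert!((got - expected).abs() < 1e-3, "{got} vs {expected}");
        Ok(())
    }
}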