mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Simd128 q2k vecdot (#982)
* Sketch the simd128 version of q2k vecdot. * Use a single accumulator.
This commit is contained in:
@ -687,6 +687,9 @@ impl GgmlType for BlockQ2K {
|
|||||||
#[cfg(target_feature = "neon")]
|
#[cfg(target_feature = "neon")]
|
||||||
return super::neon::vec_dot_q2k_q8k(n, xs, ys);
|
return super::neon::vec_dot_q2k_q8k(n, xs, ys);
|
||||||
|
|
||||||
|
#[cfg(target_feature = "simd128")]
|
||||||
|
return super::simd128::vec_dot_q2k_q8k(n, xs, ys);
|
||||||
|
|
||||||
if n % QK_K != 0 {
|
if n % QK_K != 0 {
|
||||||
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
|
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
|
use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
use byteorder::{ByteOrder, LittleEndian};
|
use byteorder::{ByteOrder, LittleEndian};
|
||||||
use half::f16;
|
use half::f16;
|
||||||
@ -99,6 +99,58 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||||
|
if n % QK_K != 0 {
|
||||||
|
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
|
||||||
|
}
|
||||||
|
let mut sumf = 0.0;
|
||||||
|
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||||
|
let mut q2: &[_] = &x.qs;
|
||||||
|
let mut q8: &[_] = &y.qs;
|
||||||
|
let sc = &x.scales;
|
||||||
|
|
||||||
|
let mut summs = 0;
|
||||||
|
for (bsum, scale) in y.bsums.iter().zip(sc) {
|
||||||
|
summs += *bsum as i32 * ((scale >> 4) as i32);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dall = y.d * x.d.to_f32();
|
||||||
|
let dmin = y.d * x.dmin.to_f32();
|
||||||
|
|
||||||
|
let mut isum = 0;
|
||||||
|
let mut is = 0;
|
||||||
|
let mut d;
|
||||||
|
for _ in 0..(QK_K / 128) {
|
||||||
|
let mut shift = 0;
|
||||||
|
for _ in 0..4 {
|
||||||
|
d = (sc[is] & 0xF) as i32;
|
||||||
|
is += 1;
|
||||||
|
let mut isuml = 0;
|
||||||
|
for l in 0..16 {
|
||||||
|
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
|
||||||
|
}
|
||||||
|
isum += d * isuml;
|
||||||
|
d = (sc[is] & 0xF) as i32;
|
||||||
|
is += 1;
|
||||||
|
isuml = 0;
|
||||||
|
for l in 16..32 {
|
||||||
|
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
|
||||||
|
}
|
||||||
|
isum += d * isuml;
|
||||||
|
shift += 2;
|
||||||
|
// adjust the indexing
|
||||||
|
q8 = &q8[32..];
|
||||||
|
}
|
||||||
|
// adjust the indexing
|
||||||
|
q2 = &q2[32..];
|
||||||
|
}
|
||||||
|
sumf += dall * isum as f32 - dmin * summs as f32;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(sumf)
|
||||||
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
|
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||||
if n % QK_K != 0 {
|
if n % QK_K != 0 {
|
||||||
@ -115,7 +167,6 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
|||||||
|
|
||||||
let mut aux8: [u8; QK_K] = [0; QK_K];
|
let mut aux8: [u8; QK_K] = [0; QK_K];
|
||||||
let mut sums = f32x4_splat(0f32);
|
let mut sums = f32x4_splat(0f32);
|
||||||
let mut sumf = f32x4_splat(0f32);
|
|
||||||
unsafe {
|
unsafe {
|
||||||
for (y, x) in ys.iter().zip(xs.iter()) {
|
for (y, x) in ys.iter().zip(xs.iter()) {
|
||||||
let q4 = &x.qs;
|
let q4 = &x.qs;
|
||||||
@ -180,9 +231,8 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
|||||||
let dmin = x.dmin.to_f32() * y.d;
|
let dmin = x.dmin.to_f32() * y.d;
|
||||||
let dmin = f32x4_splat(dmin);
|
let dmin = f32x4_splat(dmin);
|
||||||
let sumi = f32x4_convert_i32x4(sumi);
|
let sumi = f32x4_convert_i32x4(sumi);
|
||||||
sumf = f32x4_add(sumf, f32x4_mul(sumi, dmin));
|
sums = f32x4_sub(sums, f32x4_mul(sumi, dmin));
|
||||||
}
|
}
|
||||||
let sums = f32x4_sub(sums, sumf);
|
|
||||||
let sums = f32x4_extract_lane::<0>(sums)
|
let sums = f32x4_extract_lane::<0>(sums)
|
||||||
+ f32x4_extract_lane::<1>(sums)
|
+ f32x4_extract_lane::<1>(sums)
|
||||||
+ f32x4_extract_lane::<2>(sums)
|
+ f32x4_extract_lane::<2>(sums)
|
||||||
|
Reference in New Issue
Block a user