simd128 optimized q8_0 vecdot (#972)

* wasm/simd128 version of the quantized q8_0 vecdot.

* Add the missing conversion.
This commit is contained in:
Laurent Mazare
2023-09-27 11:03:20 +01:00
committed by GitHub
parent 29bd6b2979
commit e59784e353
3 changed files with 54 additions and 0 deletions

View File

@ -606,6 +606,9 @@ impl GgmlType for BlockQ8_0 {
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
return super::neon::vec_dot_q8_0_q8_0(n, xs, ys); return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);
let qk = QK8_0; let qk = QK8_0;
if n % QK8_0 != 0 { if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")

View File

@ -7,6 +7,8 @@ pub mod gguf_file;
pub mod k_quants; pub mod k_quants;
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
pub mod neon; pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils; pub mod utils;
pub use k_quants::GgmlType; pub use k_quants::GgmlType;

View File

@ -0,0 +1,49 @@
use super::k_quants::{BlockQ8_0, QK8_0};
use crate::Result;
use half::f16;
use core::arch::wasm32::*;
/// Dot product of two sequences of q8_0 blocks using the wasm `simd128` intrinsics.
///
/// For every pair of blocks `(x, y)` the quantized values in `qs` are widened
/// from `i8` to `i16`, multiplied and pairwise-summed into `i32` lanes, converted
/// to `f32`, and scaled by the product of the two block scales `x.d * y.d`.
/// The four lanes of the running accumulator are summed at the end.
///
/// `n` is only used for validation; the loop walks `xs` and `ys` in lockstep and
/// stops at the end of the shorter slice.
///
/// # Errors
///
/// Returns an error when `n` is not a multiple of `QK8_0`.
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    // Each iteration below consumes exactly one block pair, so there is no
    // requirement on the parity of the block count: odd counts are accepted.

    // SAFETY: every `i16x8_load_extend_i8x8` reads 8 bytes starting at byte
    // offsets 0, 8, 16 and 24 of a block's `qs` buffer, i.e. bytes 0..32.
    // This is in bounds provided `qs` holds at least 32 `i8` values.
    // NOTE(review): `BlockQ8_0` is declared in `k_quants`; confirm `qs` is
    // `[i8; QK8_0]` with `QK8_0 == 32`.
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let x_qs = x.qs.as_ptr();
            let y_qs = y.qs.as_ptr();

            // Values 0..8.
            let x1 = i16x8_load_extend_i8x8(x_qs);
            let y1 = i16x8_load_extend_i8x8(y_qs);
            let sum_xy = i32x4_dot_i16x8(x1, y1);

            // Values 8..16.
            let x2 = i16x8_load_extend_i8x8(x_qs.add(8));
            let y2 = i16x8_load_extend_i8x8(y_qs.add(8));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));

            // Values 16..24.
            let x3 = i16x8_load_extend_i8x8(x_qs.add(16));
            let y3 = i16x8_load_extend_i8x8(y_qs.add(16));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));

            // Values 24..32.
            let x4 = i16x8_load_extend_i8x8(x_qs.add(24));
            let y4 = i16x8_load_extend_i8x8(y_qs.add(24));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));

            // The integer partial sums must be converted to f32 before scaling.
            let sum_xy = f32x4_convert_i32x4(sum_xy);

            // f32x4_relaxed_madd is nightly only, so use a separate mul + add.
            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
            let scaled = f32x4_mul(sum_xy, d);
            acc = f32x4_add(acc, scaled)
        }
        // Horizontal sum of the four accumulator lanes.
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}