From 11d3687cc655f8f79d856342a5539a9274e96df4 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 3 Oct 2023 15:29:48 +0100 Subject: [PATCH] Simd128 optimized q8k vecdot. (#1026) --- candle-core/src/quantized/k_quants.rs | 3 +++ candle-core/src/quantized/simd128.rs | 30 ++++++++++++++++++++++++ candle-examples/examples/whisper/main.rs | 2 +- 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 7567c446..b140131e 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1764,6 +1764,9 @@ impl GgmlType for BlockQ8K { #[cfg(target_feature = "neon")] return super::neon::vec_dot_q8k_q8k(n, xs, ys); + #[cfg(target_feature = "simd128")] + return super::simd128::vec_dot_q8k_q8k(n, xs, ys); + Self::vec_dot_unopt(n, xs, ys) } diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs index cc26ac10..687399c2 100644 --- a/candle-core/src/quantized/simd128.rs +++ b/candle-core/src/quantized/simd128.rs @@ -395,3 +395,33 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res Ok(sums) } } + +#[inline(always)] +pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result { + let qk = QK_K; + if n % QK_K != 0 { + crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}") + } + + unsafe { + let mut acc = f32x4_splat(0.0f32); + for (xs, ys) in xs.iter().zip(ys.iter()) { + let x_qs = xs.qs.as_ptr(); + let y_qs = ys.qs.as_ptr(); + let mut sumi = i32x4_splat(0); + for j in (0..QK_K).step_by(8) { + let xs = i16x8_load_extend_i8x8(x_qs.add(j)); + let ys = i16x8_load_extend_i8x8(y_qs.add(j)); + let sum_xy = i32x4_dot_i16x8(xs, ys); + sumi = i32x4_add(sumi, sum_xy) + } + let d = f32x4_splat(xs.d * ys.d); + acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d)) + } + let res = f32x4_extract_lane::<0>(acc) + + f32x4_extract_lane::<1>(acc) + + f32x4_extract_lane::<2>(acc) + + f32x4_extract_lane::<3>(acc); + Ok(res) + } +} diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 5d4b624e..07247451 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -493,7 +493,7 @@ fn main() -> Result<()> { ( repo.get(&format!("config-{ext}.json"))?, repo.get(&format!("tokenizer-{ext}.json"))?, - repo.get(&format!("model-{ext}-q40.gguf"))?, + repo.get(&format!("model-{ext}-q80.gguf"))?, ) } else { (