diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 602ea25f..180724a4 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1539,6 +1539,9 @@ impl GgmlType for BlockQ6K {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q6k_q8k(n, xs, ys);
 
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q6k_q8k(n, xs, ys);
+
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
         }
diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs
index bddeda7e..cc26ac10 100644
--- a/candle-core/src/quantized/simd128.rs
+++ b/candle-core/src/quantized/simd128.rs
@@ -1,4 +1,4 @@
-use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
+use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
 use crate::Result;
 use byteorder::{ByteOrder, LittleEndian};
 use half::f16;
@@ -272,3 +272,126 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
         Ok(sums)
     }
 }
+
+#[inline(always)]
+pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
+    if n % QK_K != 0 {
+        crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
+    }
+
+    let mut aux8 = [0i8; QK_K];
+    unsafe {
+        let mut sums = f32x4_splat(0f32);
+
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let q4 = &x.ql;
+            let qh = &x.qh;
+            let q8 = &y.qs;
+            let mut aux32 = f32x4_splat(0f32);
+
+            for j in (0..QK_K).step_by(128) {
+                let aux8 = aux8.as_mut_ptr().add(j);
+                let q4 = &q4.as_ptr().add(j / 2);
+                let qh = &qh.as_ptr().add(j / 4);
+                for l in (0..32).step_by(16) {
+                    // aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
+                    let a8 = v128_or(
+                        v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)),
+                        u8x16_shl(
+                            v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)),
+                            4,
+                        ),
+                    );
+                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
+                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
+                    v128_store(
+                        aux8.add(l) as *mut v128,
+                        i8x16_narrow_i16x8(a8_low, a8_high),
+                    );
+
+                    // aux8[l + 32] =
+                    //     (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
+                    let a8 = v128_or(
+                        v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)),
+                        u8x16_shl(
+                            v128_and(
+                                u8x16_shr(v128_load(qh.add(l) as *const v128), 2),
+                                u8x16_splat(3),
+                            ),
+                            4,
+                        ),
+                    );
+                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
+                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
+                    v128_store(
+                        aux8.add(l + 32) as *mut v128,
+                        i8x16_narrow_i16x8(a8_low, a8_high),
+                    );
+
+                    // aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
+                    let a8 = v128_or(
+                        u8x16_shr(v128_load(q4.add(l) as *const v128), 4),
+                        u8x16_shl(
+                            v128_and(
+                                u8x16_shr(v128_load(qh.add(l) as *const v128), 4),
+                                u8x16_splat(3),
+                            ),
+                            4,
+                        ),
+                    );
+                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
+                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
+                    v128_store(
+                        aux8.add(l + 64) as *mut v128,
+                        i8x16_narrow_i16x8(a8_low, a8_high),
+                    );
+
+                    // aux8[l + 96] =
+                    //     (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
+                    let a8 = v128_or(
+                        u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4),
+                        u8x16_shl(
+                            v128_and(
+                                u8x16_shr(v128_load(qh.add(l) as *const v128), 6),
+                                u8x16_splat(3),
+                            ),
+                            4,
+                        ),
+                    );
+                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
+                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
+                    v128_store(
+                        aux8.add(l + 96) as *mut v128,
+                        i8x16_narrow_i16x8(a8_low, a8_high),
+                    );
+                }
+            }
+
+            for (j, &scale) in x.scales.iter().enumerate() {
+                let scale = f32x4_splat(scale as f32);
+                for offset in [0, 8] {
+                    let aux16 = i16x8_mul(
+                        i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)),
+                        i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)),
+                    );
+                    aux32 = f32x4_add(
+                        aux32,
+                        f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale),
+                    );
+                    aux32 = f32x4_add(
+                        aux32,
+                        f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale),
+                    );
+                }
+            }
+
+            let d = f32x4_splat(x.d.to_f32() * y.d);
+            sums = f32x4_add(sums, f32x4_mul(aux32, d));
+        }
+        let sums = f32x4_extract_lane::<0>(sums)
+            + f32x4_extract_lane::<1>(sums)
+            + f32x4_extract_lane::<2>(sums)
+            + f32x4_extract_lane::<3>(sums);
+        Ok(sums)
+    }
+}
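
For reference, the arithmetic the new vec_dot_q6k_q8k kernel vectorizes is the same as the scalar comments inside the loop: each 6-bit weight is rebuilt from a low nibble in ql plus two high bits in qh, recentered by subtracting 32, multiplied by the matching i8 value from the Q8_K block, weighted by the per-16-element i8 scale, and finally scaled by d = x.d * y.d. The sketch below is illustrative only and not part of the diff; it mirrors the BlockQ6K/BlockQ8K field layout (ql, qh, scales, d, qs, QK_K = 256), but the free-standing helper name and signature are assumptions made for the example.

// Scalar reference for one Q6_K x Q8_K block pair (hypothetical helper, not in the change).
const QK_K: usize = 256;

fn dot_q6k_q8k_block_scalar(
    ql: &[u8; QK_K / 2],      // low 4 bits of each 6-bit weight
    qh: &[u8; QK_K / 4],      // high 2 bits, packed four per byte
    scales: &[i8; QK_K / 16], // one i8 scale per 16 weights
    d: f32,                   // x.d.to_f32() * y.d
    q8: &[i8; QK_K],          // Q8_K quants
) -> f32 {
    // Unpack the 256 6-bit weights into signed i8 values centered at zero,
    // as the four v128_or / u8x16_shr blocks above do 16 lanes at a time.
    let mut aux8 = [0i8; QK_K];
    for j in (0..QK_K).step_by(128) {
        let (ql, qh) = (&ql[j / 2..], &qh[j / 4..]);
        for l in 0..32 {
            aux8[j + l] = (((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
            aux8[j + l + 32] = (((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
            aux8[j + l + 64] = (((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
            aux8[j + l + 96] = (((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
        }
    }
    // Dot product: each 16-weight sub-block carries its own i8 scale, and the
    // whole block is then scaled by d.
    let mut sum = 0f32;
    for (j, &scale) in scales.iter().enumerate() {
        let mut acc = 0i32;
        for i in 16 * j..16 * (j + 1) {
            acc += q8[i] as i32 * aux8[i] as i32;
        }
        sum += acc as f32 * scale as f32;
    }
    d * sum
}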