From d99cac3ec38a52bd81cc72059259729e7272e490 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 17 Aug 2023 09:01:06 +0100 Subject: [PATCH] Move the avx specific bits to a separate file. (#481) --- candle-core/src/quantized/avx.rs | 72 ++++++++ candle-core/src/quantized/k_quants.rs | 161 +++++------------- candle-core/src/quantized/mod.rs | 2 + .../examples/{ggml => quantized}/main.rs | 0 4 files changed, 119 insertions(+), 116 deletions(-) create mode 100644 candle-core/src/quantized/avx.rs rename candle-examples/examples/{ggml => quantized}/main.rs (100%) diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs new file mode 100644 index 00000000..27bad26d --- /dev/null +++ b/candle-core/src/quantized/avx.rs @@ -0,0 +1,72 @@ +use super::k_quants::{BlockQ4_0, BlockQ8_0, QK8_0}; +use crate::Result; +use half::f16; + +#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +#[inline(always)] +pub(crate) unsafe fn sum_i16_pairs_float(x: __m256i) -> __m256 { + let ones = _mm256_set1_epi16(1); + let summed_pairs = _mm256_madd_epi16(ones, x); + _mm256_cvtepi32_ps(summed_pairs) +} + +#[inline(always)] +pub(crate) unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 { + let dot = _mm256_maddubs_epi16(ax, sy); + sum_i16_pairs_float(dot) +} + +#[inline(always)] +pub(crate) unsafe fn hsum_float_8(x: __m256) -> f32 { + let mut res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + _mm_cvtss_f32(res) +} + +#[inline(always)] +pub(crate) unsafe fn bytes_from_nibbles_32(rsi: *const u8) -> __m256i { + let tmp = _mm_loadu_si128(rsi as *const __m128i); + let bytes = _mm256_insertf128_si256::<1>(_mm256_castsi128_si256(tmp), _mm_srli_epi16(tmp, 4)); + let low_mask = _mm256_set1_epi8(0xF); + _mm256_and_si256(low_mask, bytes) +} + +#[inline(always)] +pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 { + let ax = _mm256_sign_epi8(x, x); + let sy = _mm256_sign_epi8(y, x); + mul_sum_us8_pairs_float(ax, sy) +} + +#[inline(always)] +pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result { + let qk = QK8_0; + let nb = n / qk; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") + } + if nb % 2 != 0 { + crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") + } + + unsafe { + // Generic implementation. + let mut acc = _mm256_setzero_ps(); + for (x, y) in xs.iter().zip(ys.iter()) { + let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d)); + let bx = bytes_from_nibbles_32(x.qs.as_ptr()); + let off = _mm256_set1_epi8(8); + let bx = _mm256_sub_epi8(bx, off); + let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i); + let q = mul_sum_i8_pairs_float(bx, by); + acc = _mm256_fmadd_ps(d, q, acc); + } + Ok(hsum_float_8(acc)) + } +} diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 27d2ee3a..8616e375 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1,8 +1,3 @@ -#[cfg(target_arch = "x86")] -use core::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use core::arch::x86_64::*; - use super::GgmlDType; use crate::Result; use half::f16; @@ -38,73 +33,73 @@ pub trait GgmlType: Sized + Clone { #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ4_0 { - d: f16, - qs: [u8; QK4_0 / 2], + pub(crate) d: f16, + pub(crate) qs: [u8; QK4_0 / 2], } const _: () = assert!(std::mem::size_of::() == 18); #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ4_1 { - d: f16, - m: f16, - qs: [u8; QK4_1 / 2], + pub(crate) d: f16, + pub(crate) m: f16, + pub(crate) qs: [u8; QK4_1 / 2], } const _: () = assert!(std::mem::size_of::() == 20); #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ5_0 { - d: f16, - qh: [u8; 4], - qs: [u8; QK5_0 / 2], + pub(crate) d: f16, + pub(crate) qh: [u8; 4], + pub(crate) qs: [u8; QK5_0 / 2], } const _: () = assert!(std::mem::size_of::() == 22); #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ5_1 { - d: f16, - m: f16, - qh: [u8; 4], - qs: [u8; QK5_1 / 2], + pub(crate) d: f16, + pub(crate) m: f16, + pub(crate) qh: [u8; 4], + pub(crate) qs: [u8; QK5_1 / 2], } const _: () = assert!(std::mem::size_of::() == 24); #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ8_0 { - d: f16, - qs: [i8; QK8_0], + pub(crate) d: f16, + pub(crate) qs: [i8; QK8_0], } const _: () = assert!(std::mem::size_of::() == 34); #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ8_1 { - d: f16, - s: f16, - qs: [u8; QK8_1], + pub(crate) d: f16, + pub(crate) s: f16, + pub(crate) qs: [u8; QK8_1], } const _: () = assert!(std::mem::size_of::() == 36); #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ2K { - scales: [u8; QK_K / 16], - qs: [u8; QK_K / 4], - d: f16, - dmin: f16, + pub(crate) scales: [u8; QK_K / 16], + pub(crate) qs: [u8; QK_K / 4], + pub(crate) d: f16, + pub(crate) dmin: f16, } const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::()); #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ3K { - hmask: [u8; QK_K / 8], - qs: [u8; QK_K / 4], - scales: [u8; 12], - d: f16, + pub(crate) hmask: [u8; QK_K / 8], + pub(crate) qs: [u8; QK_K / 4], + pub(crate) scales: [u8; 12], + pub(crate) d: f16, } const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::()); @@ -112,21 +107,21 @@ const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::()); #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ5K { - d: f16, - dmin: f16, - scales: [u8; K_SCALE_SIZE], - qh: [u8; QK_K / 8], - qs: [u8; QK_K / 2], + pub(crate) d: f16, + pub(crate) dmin: f16, + pub(crate) scales: [u8; K_SCALE_SIZE], + pub(crate) qh: [u8; QK_K / 8], + pub(crate) qs: [u8; QK_K / 2], } const _: () = assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::()); @@ -134,19 +129,19 @@ const _: () = #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ6K { - ql: [u8; QK_K / 2], - qh: [u8; QK_K / 4], - scales: [i8; QK_K / 16], - d: f16, + pub(crate) ql: [u8; QK_K / 2], + pub(crate) qh: [u8; QK_K / 4], + pub(crate) scales: [i8; QK_K / 16], + pub(crate) d: f16, } const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::()); #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ8K { - d: f32, - qs: [i8; QK_K], - bsums: [i16; QK_K / 16], + pub(crate) d: f32, + pub(crate) qs: [i8; QK_K], + pub(crate) bsums: [i16; QK_K / 16], } const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::()); @@ -625,48 +620,6 @@ impl GgmlType for BlockQ8K { } } -#[cfg(target_feature = "avx")] -#[inline(always)] -unsafe fn sum_i16_pairs_float(x: __m256i) -> __m256 { - let ones = _mm256_set1_epi16(1); - let summed_pairs = _mm256_madd_epi16(ones, x); - _mm256_cvtepi32_ps(summed_pairs) -} - -#[cfg(target_feature = "avx")] -#[inline(always)] -unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 { - let dot = _mm256_maddubs_epi16(ax, sy); - sum_i16_pairs_float(dot) -} - -#[cfg(target_feature = "avx")] -#[inline(always)] -unsafe fn hsum_float_8(x: __m256) -> f32 { - let mut res = _mm256_extractf128_ps(x, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(x)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - _mm_cvtss_f32(res) -} - -#[cfg(target_feature = "avx")] -#[inline(always)] -unsafe fn bytes_from_nibbles_32(rsi: *const u8) -> __m256i { - let tmp = _mm_loadu_si128(rsi as *const __m128i); - let bytes = _mm256_insertf128_si256::<1>(_mm256_castsi128_si256(tmp), _mm_srli_epi16(tmp, 4)); - let low_mask = _mm256_set1_epi8(0xF); - _mm256_and_si256(low_mask, bytes) -} - -#[cfg(target_feature = "avx")] -#[inline(always)] -unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 { - let ax = _mm256_sign_epi8(x, x); - let sy = _mm256_sign_epi8(y, x); - mul_sum_us8_pairs_float(ax, sy) -} - impl GgmlType for BlockQ4_0 { const DTYPE: GgmlDType = GgmlDType::Q4_0; const BLCK_SIZE: usize = QK4_0; @@ -732,36 +685,12 @@ impl GgmlType for BlockQ4_0 { Ok(()) } - #[cfg(target_feature = "avx")] - fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - let qk = QK8_0; - let nb = n / qk; - if n % QK8_0 != 0 { - crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") - } - if nb % 2 != 0 { - crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") - } - - unsafe { - // Generic implementation. - let mut acc = _mm256_setzero_ps(); - for (x, y) in xs.iter().zip(ys.iter()) { - let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d)); - let bx = bytes_from_nibbles_32(x.qs.as_ptr()); - let off = _mm256_set1_epi8(8); - let bx = _mm256_sub_epi8(bx, off); - let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i); - let q = mul_sum_i8_pairs_float(bx, by); - acc = _mm256_fmadd_ps(d, q, acc); - } - Ok(hsum_float_8(acc)) - } - } - // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122 - #[cfg(not(target_feature = "avx"))] + #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q4_0_q8_0(n, xs, ys); + let qk = QK8_0; let nb = n / qk; if n % QK8_0 != 0 { diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 52dddcf5..a0ed5b4d 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -1,5 +1,7 @@ use crate::{Device, Result, Shape, Tensor}; +#[cfg(target_feature = "avx")] +pub mod avx; pub mod ggml_file; pub mod k_quants; diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/quantized/main.rs similarity index 100% rename from candle-examples/examples/ggml/main.rs rename to candle-examples/examples/quantized/main.rs