Move the avx specific bits to a separate file. (#481)

This commit is contained in:
Laurent Mazare
2023-08-17 09:01:06 +01:00
committed by GitHub
parent f708efb19c
commit d99cac3ec3
4 changed files with 119 additions and 116 deletions

View File

@ -0,0 +1,72 @@
use super::k_quants::{BlockQ4_0, BlockQ8_0, QK8_0};
use crate::Result;
use half::f16;
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
/// Adds every pair of neighbouring `i16` lanes of `x` and returns the eight resulting `i32`
/// sums converted to `f32`.
///
/// # Safety
///
/// The body executes AVX and AVX2 instructions (`_mm256_madd_epi16` is an AVX2 intrinsic), so
/// the CPU running this code must support both.
#[inline(always)]
pub(crate) unsafe fn sum_i16_pairs_float(x: __m256i) -> __m256 {
    // `madd` multiplies lane-wise and adds adjacent products; with a multiplier of 1 this is
    // simply a widening pairwise addition.
    let pair_sums = _mm256_madd_epi16(_mm256_set1_epi16(1), x);
    _mm256_cvtepi32_ps(pair_sums)
}
/// Multiplies the bytes of `ax` (read as unsigned) with the bytes of `sy` (read as signed) and
/// reduces the 32 products to eight `f32` partial sums.
///
/// `_mm256_maddubs_epi16` first adds adjacent products into `i16` lanes (with signed
/// saturation); `sum_i16_pairs_float` then adds adjacent `i16` lanes and converts to float.
///
/// # Safety
///
/// Executes AVX2 instructions; the CPU running this code must support them.
#[inline(always)]
pub(crate) unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 {
    sum_i16_pairs_float(_mm256_maddubs_epi16(ax, sy))
}
/// Horizontally adds the eight `f32` lanes of `x` and returns the scalar total.
///
/// The reduction proceeds 8 -> 4 -> 2 -> 1 lanes: upper half plus lower half, then the high
/// pair plus the low pair, then the odd lane plus the even lane.
///
/// # Safety
///
/// Executes AVX and SSE3 instructions; the CPU running this code must support them.
#[inline(always)]
pub(crate) unsafe fn hsum_float_8(x: __m256) -> f32 {
    let upper_half = _mm256_extractf128_ps(x, 1);
    let lower_half = _mm256_castps256_ps128(x);
    let four = _mm_add_ps(upper_half, lower_half);
    let two = _mm_add_ps(four, _mm_movehl_ps(four, four));
    let one = _mm_add_ss(two, _mm_movehdup_ps(two));
    _mm_cvtss_f32(one)
}
/// Expands 16 packed bytes (32 nibbles) at `rsi` into 32 bytes, each in `0..=15`.
///
/// Output bytes `0..16` hold the low nibble of the corresponding input byte and output bytes
/// `16..32` hold the high nibble.
///
/// # Safety
///
/// `rsi` must be valid for an unaligned read of 16 bytes. The body executes AVX and AVX2
/// instructions, so the CPU running this code must support them.
#[inline(always)]
pub(crate) unsafe fn bytes_from_nibbles_32(rsi: *const u8) -> __m256i {
    let packed = _mm_loadu_si128(rsi as *const __m128i);
    // The 16-bit shift drags bits across byte boundaries, but those land in the upper nibble
    // of each byte and are cleared by the mask below.
    let shifted = _mm_srli_epi16(packed, 4);
    // Lower 128 bits: the raw bytes; upper 128 bits: the bytes shifted right by one nibble.
    let widened = _mm256_set_m128i(shifted, packed);
    _mm256_and_si256(widened, _mm256_set1_epi8(0xF))
}
/// Multiplies the signed bytes of `x` and `y` lane by lane and reduces the 32 products to eight
/// `f32` partial sums.
///
/// `mul_sum_us8_pairs_float` needs an unsigned left operand, so the sign of `x` is moved onto
/// `y`: `|x| * (sign(x) * y)` equals `x * y` for every lane.
///
/// # Safety
///
/// Executes AVX2 instructions; the CPU running this code must support them.
#[inline(always)]
pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
    // Applying x's own sign to x yields its absolute value.
    let x_abs = _mm256_sign_epi8(x, x);
    // Negate y wherever x is negative (and zero it wherever x is zero).
    let y_signed = _mm256_sign_epi8(y, x);
    mul_sum_us8_pairs_float(x_abs, y_signed)
}
/// SIMD dot product between a run of `BlockQ4_0` blocks and a run of `BlockQ8_0` blocks.
///
/// For every block pair the 4-bit values of `x` are unpacked, re-centred by subtracting 8,
/// multiplied with the 8-bit values of `y`, and the partial sums are scaled by the product of
/// the two block scales (`x.d * y.d`) before being accumulated.
///
/// `n` is the element count and is only used for validation; the loop itself walks
/// `xs.iter().zip(ys.iter())`, so the number of blocks processed is the shorter of the two
/// slices.
///
/// # Errors
///
/// Returns an error when `n` is not a multiple of `QK8_0`, when the block count `n / QK8_0` is
/// odd, or when the running CPU lacks the AVX2 or FMA extensions used by this kernel.
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    let nb = n / qk;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
    }
    // The kernel below uses AVX2 integer intrinsics and an FMA intrinsic, which a build that
    // only guarantees plain AVX does not cover. The check folds to a constant when both
    // features are enabled at compile time; otherwise it turns an illegal-instruction crash
    // into a regular error.
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
        crate::bail!("vec_dot_q4_0_q8_0: the avx2 and fma cpu features are required")
    }
    // SAFETY: the required CPU features were verified just above. The two unaligned loads stay
    // inside a single block each: 16 bytes from `x.qs` and 32 bytes from `y.qs`, matching the
    // 18-byte and 34-byte block sizes (a 2-byte `f16` followed by the quants) asserted next to
    // the block definitions.
    unsafe {
        // Loop-invariant offset that maps the unsigned nibbles 0..=15 onto -8..=7.
        let off = _mm256_set1_epi8(8);
        let mut acc = _mm256_setzero_ps();
        for (x, y) in xs.iter().zip(ys.iter()) {
            // Combined scale of the two blocks, broadcast to all eight lanes.
            let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
            let bx = _mm256_sub_epi8(bytes_from_nibbles_32(x.qs.as_ptr()), off);
            let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
            let q = mul_sum_i8_pairs_float(bx, by);
            // acc += d * q
            acc = _mm256_fmadd_ps(d, q, acc);
        }
        Ok(hsum_float_8(acc))
    }
}

View File

@ -1,8 +1,3 @@
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use super::GgmlDType;
use crate::Result;
use half::f16;
@ -38,73 +33,73 @@ pub trait GgmlType: Sized + Clone {
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4_0 {
d: f16,
qs: [u8; QK4_0 / 2],
pub(crate) d: f16,
pub(crate) qs: [u8; QK4_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4_1 {
d: f16,
m: f16,
qs: [u8; QK4_1 / 2],
pub(crate) d: f16,
pub(crate) m: f16,
pub(crate) qs: [u8; QK4_1 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5_0 {
d: f16,
qh: [u8; 4],
qs: [u8; QK5_0 / 2],
pub(crate) d: f16,
pub(crate) qh: [u8; 4],
pub(crate) qs: [u8; QK5_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5_1 {
d: f16,
m: f16,
qh: [u8; 4],
qs: [u8; QK5_1 / 2],
pub(crate) d: f16,
pub(crate) m: f16,
pub(crate) qh: [u8; 4],
pub(crate) qs: [u8; QK5_1 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8_0 {
d: f16,
qs: [i8; QK8_0],
pub(crate) d: f16,
pub(crate) qs: [i8; QK8_0],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8_1 {
d: f16,
s: f16,
qs: [u8; QK8_1],
pub(crate) d: f16,
pub(crate) s: f16,
pub(crate) qs: [u8; QK8_1],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ2K {
scales: [u8; QK_K / 16],
qs: [u8; QK_K / 4],
d: f16,
dmin: f16,
pub(crate) scales: [u8; QK_K / 16],
pub(crate) qs: [u8; QK_K / 4],
pub(crate) d: f16,
pub(crate) dmin: f16,
}
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ3K {
hmask: [u8; QK_K / 8],
qs: [u8; QK_K / 4],
scales: [u8; 12],
d: f16,
pub(crate) hmask: [u8; QK_K / 8],
pub(crate) qs: [u8; QK_K / 4],
pub(crate) scales: [u8; 12],
pub(crate) d: f16,
}
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
@ -112,21 +107,21 @@ const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
#[repr(C)]
pub struct BlockQ4K {
d: f16,
dmin: f16,
scales: [u8; K_SCALE_SIZE],
qs: [u8; QK_K / 2],
pub(crate) d: f16,
pub(crate) dmin: f16,
pub(crate) scales: [u8; K_SCALE_SIZE],
pub(crate) qs: [u8; QK_K / 2],
}
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5K {
d: f16,
dmin: f16,
scales: [u8; K_SCALE_SIZE],
qh: [u8; QK_K / 8],
qs: [u8; QK_K / 2],
pub(crate) d: f16,
pub(crate) dmin: f16,
pub(crate) scales: [u8; K_SCALE_SIZE],
pub(crate) qh: [u8; QK_K / 8],
pub(crate) qs: [u8; QK_K / 2],
}
const _: () =
assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
@ -134,19 +129,19 @@ const _: () =
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ6K {
ql: [u8; QK_K / 2],
qh: [u8; QK_K / 4],
scales: [i8; QK_K / 16],
d: f16,
pub(crate) ql: [u8; QK_K / 2],
pub(crate) qh: [u8; QK_K / 4],
pub(crate) scales: [i8; QK_K / 16],
pub(crate) d: f16,
}
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8K {
d: f32,
qs: [i8; QK_K],
bsums: [i16; QK_K / 16],
pub(crate) d: f32,
pub(crate) qs: [i8; QK_K],
pub(crate) bsums: [i16; QK_K / 16],
}
const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::<BlockQ8K>());
@ -625,48 +620,6 @@ impl GgmlType for BlockQ8K {
}
}
#[cfg(target_feature = "avx")]
#[inline(always)]
unsafe fn sum_i16_pairs_float(x: __m256i) -> __m256 {
let ones = _mm256_set1_epi16(1);
let summed_pairs = _mm256_madd_epi16(ones, x);
_mm256_cvtepi32_ps(summed_pairs)
}
#[cfg(target_feature = "avx")]
#[inline(always)]
unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 {
let dot = _mm256_maddubs_epi16(ax, sy);
sum_i16_pairs_float(dot)
}
#[cfg(target_feature = "avx")]
#[inline(always)]
unsafe fn hsum_float_8(x: __m256) -> f32 {
let mut res = _mm256_extractf128_ps(x, 1);
res = _mm_add_ps(res, _mm256_castps256_ps128(x));
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
res = _mm_add_ss(res, _mm_movehdup_ps(res));
_mm_cvtss_f32(res)
}
#[cfg(target_feature = "avx")]
#[inline(always)]
unsafe fn bytes_from_nibbles_32(rsi: *const u8) -> __m256i {
let tmp = _mm_loadu_si128(rsi as *const __m128i);
let bytes = _mm256_insertf128_si256::<1>(_mm256_castsi128_si256(tmp), _mm_srli_epi16(tmp, 4));
let low_mask = _mm256_set1_epi8(0xF);
_mm256_and_si256(low_mask, bytes)
}
#[cfg(target_feature = "avx")]
#[inline(always)]
unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
let ax = _mm256_sign_epi8(x, x);
let sy = _mm256_sign_epi8(y, x);
mul_sum_us8_pairs_float(ax, sy)
}
impl GgmlType for BlockQ4_0 {
const DTYPE: GgmlDType = GgmlDType::Q4_0;
const BLCK_SIZE: usize = QK4_0;
@ -732,36 +685,12 @@ impl GgmlType for BlockQ4_0 {
Ok(())
}
#[cfg(target_feature = "avx")]
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
// Generic implementation.
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
let bx = bytes_from_nibbles_32(x.qs.as_ptr());
let off = _mm256_set1_epi8(8);
let bx = _mm256_sub_epi8(bx, off);
let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
let q = mul_sum_i8_pairs_float(bx, by);
acc = _mm256_fmadd_ps(d, q, acc);
}
Ok(hsum_float_8(acc))
}
}
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122
#[cfg(not(target_feature = "avx"))]
#[allow(unreachable_code)]
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
#[cfg(target_feature = "avx")]
return super::avx::vec_dot_q4_0_q8_0(n, xs, ys);
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {

View File

@ -1,5 +1,7 @@
use crate::{Device, Result, Shape, Tensor};
#[cfg(target_feature = "avx")]
pub mod avx;
pub mod ggml_file;
pub mod k_quants;