Small cleanups (avoid some possible mutations) (#670)

* More mut cleanup.

* Factor out some common bits.
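The recurring change below swaps `let mut` plus reassignment for `let` rebinding (shadowing), so every intermediate SIMD value stays immutable. A minimal sketch of the pattern on plain floats, not taken from the diff itself:

    // Before: one mutable binding, updated in place at each step.
    fn pipeline_mut(x: f32) -> f32 {
        let mut r = x.abs();
        r = r.sqrt();
        r = r * 2.0;
        r
    }

    // After: each step shadows the previous immutable binding; this is
    // the shape applied to the intrinsic chains in the hunks below.
    fn pipeline_shadowed(x: f32) -> f32 {
        let r = x.abs();
        let r = r.sqrt();
        let r = r * 2.0;
        r
    }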
commit 9b25113393
parent a1a5ab8b0a
Author: Laurent Mazare
Date:   2023-08-30 08:54:00 +01:00 (committed by GitHub)


@@ -25,10 +25,10 @@ pub(crate) unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256
 #[inline(always)]
 pub(crate) unsafe fn hsum_float_8(x: __m256) -> f32 {
-    let mut res = _mm256_extractf128_ps(x, 1);
-    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
-    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
-    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    let res = _mm256_extractf128_ps(x, 1);
+    let res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    let res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    let res = _mm_add_ss(res, _mm_movehdup_ps(res));
     _mm_cvtss_f32(res)
 }
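For orientation, `hsum_float_8` folds eight f32 lanes down to one scalar in three halving steps. A scalar model of the same reduction tree (hypothetical helper, not part of this commit):

    // 8 lanes -> 4 -> 2 -> 1, mirroring the intrinsic sequence above.
    fn hsum_float_8_scalar(x: [f32; 8]) -> f32 {
        // _mm256_extractf128_ps + _mm_add_ps: high half onto low half.
        let r4 = [x[0] + x[4], x[1] + x[5], x[2] + x[6], x[3] + x[7]];
        // _mm_movehl_ps + _mm_add_ps: upper pair onto lower pair.
        let r2 = [r4[0] + r4[2], r4[1] + r4[3]];
        // _mm_movehdup_ps + _mm_add_ss: lane 1 added into lane 0.
        r2[0] + r2[1]
    }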
@@ -226,7 +226,7 @@ unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i {
     _mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1)
 }
 
-#[cfg_attr(not(debug_assertions), inline(always))]
+#[inline(always)]
 pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
     if n % QK_K != 0 {
         crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
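The attribute swap above repeats in the later hunks: the old `cfg_attr` form only requests inlining when debug assertions are disabled (typically release builds), while the new form applies unconditionally. Side by side, on a stand-in function:

    // Old: inline(always) attached only when debug_assertions is off.
    #[cfg_attr(not(debug_assertions), inline(always))]
    fn dot_old(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    // New: the inlining hint holds in every build profile.
    #[inline(always)]
    fn dot_new(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }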
@@ -282,18 +282,22 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
             let q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
             let q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
 
-            let mut p0 = _mm256_maddubs_epi16(q2_0, q8_0);
-            let mut p1 = _mm256_maddubs_epi16(q2_1, q8_1);
-            let mut p2 = _mm256_maddubs_epi16(q2_2, q8_2);
-            let mut p3 = _mm256_maddubs_epi16(q2_3, q8_3);
+            let p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+            let p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+            let p2 = _mm256_maddubs_epi16(q2_2, q8_2);
+            let p3 = _mm256_maddubs_epi16(q2_3, q8_3);
 
-            p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0);
-            p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1);
-            p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2);
-            p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3);
+            let p0 =
+                _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0);
+            let p1 =
+                _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1);
+            let p2 =
+                _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2);
+            let p3 =
+                _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3);
 
-            p0 = _mm256_add_epi32(p0, p1);
-            p2 = _mm256_add_epi32(p2, p3);
+            let p0 = _mm256_add_epi32(p0, p1);
+            let p2 = _mm256_add_epi32(p2, p3);
             sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
         }
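Numerically, each `_mm256_maddubs_epi16` multiplies unsigned 2-bit quants by signed i8 activations and sums adjacent pairs into i16 lanes, and the following `_mm256_madd_epi16` folds in the sub-block scales while widening pairs to i32. A scalar model of one lane pair (sketch, shapes reduced for illustration):

    // maddubs: (u8 * i8) + (u8 * i8), saturating, into one i16 lane.
    fn maddubs_pair(q: [u8; 2], a: [i8; 2]) -> i16 {
        (q[0] as i16 * a[0] as i16).saturating_add(q[1] as i16 * a[1] as i16)
    }

    // madd: (i16 * i16) + (i16 * i16), widened into one i32 lane; here the
    // first operand carries the shuffled per-sub-block scales.
    fn madd_pair(scale: [i16; 2], p: [i16; 2]) -> i32 {
        scale[0] as i32 * p[0] as i32 + scale[1] as i32 * p[1] as i32
    }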
@@ -304,7 +308,7 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
     }
 }
 
-#[cfg_attr(not(debug_assertions), inline(always))]
+#[inline(always)]
 pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
     if n % QK_K != 0 {
         crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
@@ -328,13 +332,13 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
         let mut q8 = y.qs.as_ptr();
 
         LittleEndian::read_u32_into(&x.scales, &mut aux);
 
-        let mut scales128 = _mm_set_epi32(
+        let scales128 = _mm_set_epi32(
             (((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4)) as i32,
             (((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4)) as i32,
             ((aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4)) as i32,
             ((aux[0] & KMASK2) | (((aux[2]) & KMASK1) << 4)) as i32,
         );
-        scales128 = _mm_sub_epi8(scales128, m32);
+        let scales128 = _mm_sub_epi8(scales128, m32);
         let all_scales = _mm256_cvtepi8_epi16(scales128);
         let l_scales = _mm256_extracti128_si256(all_scales, 0);
         let h_scales = _mm256_extracti128_si256(all_scales, 1);
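Each Q3K sub-block scale is 6 bits, split across the 12 scale bytes: 4 low bits in `aux[0]`/`aux[1]` and 2 high bits in `aux[2]`. All four `_mm_set_epi32` arguments follow one formula; a scalar sketch, assuming the usual k-quant masks `KMASK1 = 0x03030303` and `KMASK2 = 0x0f0f0f0f` (their definitions sit outside this hunk):

    const KMASK1: u32 = 0x0303_0303; // assumed: 2 high bits per byte
    const KMASK2: u32 = 0x0f0f_0f0f; // assumed: 4 low bits per byte

    // One 32-bit lane of scales128: low nibble | (2 high bits << 4).
    // E.g. the lowest lane above is unpack_scales(aux[0], aux[2], 0).
    fn unpack_scales(lo: u32, hi: u32, hi_shift: u32) -> u32 {
        (lo & KMASK2) | (((hi >> hi_shift) & KMASK1) << 4)
    }

The later `_mm_sub_epi8(scales128, m32)` then recenters each scale byte by subtracting 32.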
@@ -346,7 +350,6 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
         // high bit
         let hbits = _mm256_loadu_si256(x.hmask.as_ptr() as *const __m256i);
 
-        // integer accumulator
         let mut sumi = _mm256_setzero_si256();
 
         for (j, scale) in scales.iter().enumerate() {
@@ -354,83 +357,39 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
             let q3bits = _mm256_loadu_si256(q3 as *const __m256i);
             q3 = q3.add(32);
 
-            // prepare low and high bits
+            // Prepare low and high bits
             // We hardcode the shifts here to avoid loading them into a separate register
             let q3l_0 = _mm256_and_si256(q3bits, m3);
             let q3h_0 = if j == 0 {
-                _mm256_slli_epi16(
-                    _mm256_srli_epi16(
-                        _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)),
-                        0,
-                    ),
-                    2,
-                )
+                _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
             } else {
-                _mm256_slli_epi16(
-                    _mm256_srli_epi16(
-                        _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)),
-                        4,
-                    ),
-                    2,
-                )
+                _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)), 4)
             };
+            let q3h_0 = _mm256_slli_epi16(q3h_0, 2);
 
             let q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
             let q3h_1 = if j == 0 {
-                _mm256_slli_epi16(
-                    _mm256_srli_epi16(
-                        _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)),
-                        1,
-                    ),
-                    2,
-                )
+                _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)), 1)
             } else {
-                _mm256_slli_epi16(
-                    _mm256_srli_epi16(
-                        _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)),
-                        5,
-                    ),
-                    2,
-                )
+                _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)), 5)
             };
+            let q3h_1 = _mm256_slli_epi16(q3h_1, 2);
 
             let q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
             let q3h_2 = if j == 0 {
-                _mm256_slli_epi16(
-                    _mm256_srli_epi16(
-                        _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)),
-                        2,
-                    ),
-                    2,
-                )
+                _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)), 2)
             } else {
-                _mm256_slli_epi16(
-                    _mm256_srli_epi16(
-                        _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)),
-                        6,
-                    ),
-                    2,
-                )
+                _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)), 6)
             };
+            let q3h_2 = _mm256_slli_epi16(q3h_2, 2);
 
             let q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
             let q3h_3 = if j == 0 {
-                _mm256_slli_epi16(
-                    _mm256_srli_epi16(
-                        _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)),
-                        3,
-                    ),
-                    2,
-                )
+                _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)), 3)
             } else {
-                _mm256_slli_epi16(
-                    _mm256_srli_epi16(
-                        _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)),
-                        7,
-                    ),
-                    2,
-                )
+                _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)), 7)
             };
+            let q3h_3 = _mm256_slli_epi16(q3h_3, 2);
 
             // load Q8 quants
             let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
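The `if j == 0` branches exist because a single `hbits` load carries the high bits for both passes of the inner loop: pass 0 isolates mask bits 0 through 3, pass 1 bits 4 through 7. Each selection is an `andnot` against a shifted all-ones mask, a shift back down to bit 0, and the now-hoisted shift up to bit 2. A scalar model of one 16-bit lane (sketch; `mone` is assumed to hold 1 in each lane, as the shifts suggest):

    // Isolate mask bit k, normalize it to bit 0, then position it at bit 2.
    fn q3h_lane(hbits: u16, k: u32) -> u16 {
        let mask = 1u16 << k;       // _mm256_slli_epi16(mone, k)
        let sel = !hbits & mask;    // _mm256_andnot_si256(hbits, mask)
        (sel >> k) << 2             // _mm256_srli_epi16(.., k), then slli by 2
    }

Pulling the trailing `<< 2` out of both branches into one `let q3h_x = _mm256_slli_epi16(q3h_x, 2);` line is the "Factor out some common bits" part of the commit message.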
@@ -442,37 +401,38 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
             let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
             q8 = q8.add(32);
 
-            // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
-            // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
-            // and 2 if the high bit was set)
+            // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we
+            // can use _mm256_maddubs_epi16, and then subtract. The high bit part has the 2
+            // already subtracted (and so, it is zero if the high bit was not set, and 2 if the
+            // high bit was set)
             let q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
             let q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
             let q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
             let q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
 
-            let mut p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
-            let mut p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
-            let mut p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
-            let mut p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
+            let p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+            let p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+            let p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
+            let p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
 
-            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
-            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
-            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
-            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+            let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+            let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+            let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+            let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
 
             // multiply with scales
-            p16_0 =
+            let p16_0 =
                 _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(0)), p16_0);
-            p16_1 =
+            let p16_1 =
                 _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(1)), p16_1);
-            p16_2 =
+            let p16_2 =
                 _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(2)), p16_2);
-            p16_3 =
+            let p16_3 =
                 _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(3)), p16_3);
 
             // accumulate
-            p16_0 = _mm256_add_epi32(p16_0, p16_1);
-            p16_2 = _mm256_add_epi32(p16_2, p16_3);
+            let p16_0 = _mm256_add_epi32(p16_0, p16_1);
+            let p16_2 = _mm256_add_epi32(p16_2, p16_3);
             sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
         }
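The subtraction pattern is how a signed dot product is built from `maddubs`, whose first operand must be unsigned: `q3l * q8` is the unsigned low-bits product, `q8s` is the high-bit correction the comment describes, and `p16 = q3l * q8 - q8s` restores the signed result. Per element, a scalar model (sketch):

    // Two unsigned-operand products and one subtraction stand in for a
    // signed multiply, mirroring maddubs + sub_epi16 above.
    fn q3_dot_element(q3l: u8, q3h: u8, q8: i8) -> i16 {
        let low = q3l as i16 * q8 as i16; // low 2 bits contribution
        let cor = q3h as i16 * q8 as i16; // high-bit correction term
        low - cor
    }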
@@ -567,7 +527,7 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
     }
 }
 
-#[cfg_attr(not(debug_assertions), inline(always))]
+#[inline(always)]
 pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
     if n % QK_K != 0 {
         crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
@@ -664,11 +624,11 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
             let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
             q8 = q8.add(32);
 
-            let mut p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
-            let mut p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
-            p16_0 = _mm256_madd_epi16(scale_0, p16_0);
-            p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+            let p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
+            let p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
+            let p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+            let p16_1 = _mm256_madd_epi16(scale_1, p16_1);
             sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
         }