diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs index e5fa058d..37a4e3ba 100644 --- a/candle-core/src/quantized/avx.rs +++ b/candle-core/src/quantized/avx.rs @@ -25,10 +25,10 @@ pub(crate) unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 #[inline(always)] pub(crate) unsafe fn hsum_float_8(x: __m256) -> f32 { - let mut res = _mm256_extractf128_ps(x, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(x)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); + let res = _mm256_extractf128_ps(x, 1); + let res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + let res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + let res = _mm_add_ss(res, _mm_movehdup_ps(res)); _mm_cvtss_f32(res) } @@ -226,7 +226,7 @@ unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i { _mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1) } -#[cfg_attr(not(debug_assertions), inline(always))] +#[inline(always)] pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") @@ -282,18 +282,22 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res let q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); let q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); - let mut p0 = _mm256_maddubs_epi16(q2_0, q8_0); - let mut p1 = _mm256_maddubs_epi16(q2_1, q8_1); - let mut p2 = _mm256_maddubs_epi16(q2_2, q8_2); - let mut p3 = _mm256_maddubs_epi16(q2_3, q8_3); + let p0 = _mm256_maddubs_epi16(q2_0, q8_0); + let p1 = _mm256_maddubs_epi16(q2_1, q8_1); + let p2 = _mm256_maddubs_epi16(q2_2, q8_2); + let p3 = _mm256_maddubs_epi16(q2_3, q8_3); - p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0); - p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1); - p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2); - p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3); + let p0 = + _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0); + let p1 = + _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1); + let p2 = + _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2); + let p3 = + _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3); - p0 = _mm256_add_epi32(p0, p1); - p2 = _mm256_add_epi32(p2, p3); + let p0 = _mm256_add_epi32(p0, p1); + let p2 = _mm256_add_epi32(p2, p3); sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); } @@ -304,7 +308,7 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res } } -#[cfg_attr(not(debug_assertions), inline(always))] +#[inline(always)] pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}") @@ -328,13 +332,13 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res let mut q8 = y.qs.as_ptr(); LittleEndian::read_u32_into(&x.scales, &mut aux); - let mut scales128 = _mm_set_epi32( + let scales128 = _mm_set_epi32( (((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4)) as i32, (((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4)) as i32, ((aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4)) as i32, ((aux[0] & KMASK2) | (((aux[2]) & KMASK1) << 4)) as i32, ); - scales128 = _mm_sub_epi8(scales128, m32); + let scales128 = _mm_sub_epi8(scales128, m32); let all_scales = _mm256_cvtepi8_epi16(scales128); let l_scales = _mm256_extracti128_si256(all_scales, 0); let h_scales = _mm256_extracti128_si256(all_scales, 1); @@ -346,7 +350,6 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res // high bit let hbits = _mm256_loadu_si256(x.hmask.as_ptr() as *const __m256i); - // integer accumulator let mut sumi = _mm256_setzero_si256(); for (j, scale) in scales.iter().enumerate() { @@ -354,83 +357,39 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res let q3bits = _mm256_loadu_si256(q3 as *const __m256i); q3 = q3.add(32); - // prepare low and high bits - //We hardcode the shifts here to avoid loading them into a seperate register + // Prepare low and high bits + // We hardcode the shifts here to avoid loading them into a seperate register let q3l_0 = _mm256_and_si256(q3bits, m3); let q3h_0 = if j == 0 { - _mm256_slli_epi16( - _mm256_srli_epi16( - _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), - 0, - ), - 2, - ) + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0) } else { - _mm256_slli_epi16( - _mm256_srli_epi16( - _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)), - 4, - ), - 2, - ) + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)), 4) }; + let q3h_0 = _mm256_slli_epi16(q3h_0, 2); let q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); let q3h_1 = if j == 0 { - _mm256_slli_epi16( - _mm256_srli_epi16( - _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)), - 1, - ), - 2, - ) + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)), 1) } else { - _mm256_slli_epi16( - _mm256_srli_epi16( - _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)), - 5, - ), - 2, - ) + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)), 5) }; + let q3h_1 = _mm256_slli_epi16(q3h_1, 2); let q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); let q3h_2 = if j == 0 { - _mm256_slli_epi16( - _mm256_srli_epi16( - _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)), - 2, - ), - 2, - ) + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)), 2) } else { - _mm256_slli_epi16( - _mm256_srli_epi16( - _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)), - 6, - ), - 2, - ) + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)), 6) }; + let q3h_2 = _mm256_slli_epi16(q3h_2, 2); let q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); let q3h_3 = if j == 0 { - _mm256_slli_epi16( - _mm256_srli_epi16( - _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)), - 3, - ), - 2, - ) + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)), 3) } else { - _mm256_slli_epi16( - _mm256_srli_epi16( - _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)), - 7, - ), - 2, - ) + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)), 7) }; + let q3h_3 = _mm256_slli_epi16(q3h_3, 2); // load Q8 quants let q8_0 = _mm256_loadu_si256(q8 as *const __m256i); @@ -442,37 +401,38 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res let q8_3 = _mm256_loadu_si256(q8 as *const __m256i); q8 = q8.add(32); - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we + // can use _mm256_maddubs_epi16, and then subtract. The high bit part has the 2 + // already subtracted (and so, it is zero if the high bit was not set, and 2 if the + // high bit was set) let q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); let q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); let q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); let q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); - let mut p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); - let mut p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); - let mut p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); - let mut p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + let p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + let p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + let p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + let p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + let p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + let p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + let p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + let p16_3 = _mm256_sub_epi16(p16_3, q8s_3); // multiply with scales - p16_0 = + let p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(0)), p16_0); - p16_1 = + let p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(1)), p16_1); - p16_2 = + let p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(2)), p16_2); - p16_3 = + let p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(3)), p16_3); // accumulate - p16_0 = _mm256_add_epi32(p16_0, p16_1); - p16_2 = _mm256_add_epi32(p16_2, p16_3); + let p16_0 = _mm256_add_epi32(p16_0, p16_1); + let p16_2 = _mm256_add_epi32(p16_2, p16_3); sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); } @@ -567,7 +527,7 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res } } -#[cfg_attr(not(debug_assertions), inline(always))] +#[inline(always)] pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}") @@ -664,11 +624,11 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res let q8_1 = _mm256_loadu_si256(q8 as *const __m256i); q8 = q8.add(32); - let mut p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); - let mut p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + let p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); + let p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); - p16_0 = _mm256_madd_epi16(scale_0, p16_0); - p16_1 = _mm256_madd_epi16(scale_1, p16_1); + let p16_0 = _mm256_madd_epi16(scale_0, p16_0); + let p16_1 = _mm256_madd_epi16(scale_1, p16_1); sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); }