mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
AVX version for the q8-0 multiplications. (#598)
This commit is contained in:
@ -56,7 +56,6 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
// Generic implementation.
|
|
||||||
let mut acc = _mm256_setzero_ps();
|
let mut acc = _mm256_setzero_ps();
|
||||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||||
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
|
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
|
||||||
@ -71,6 +70,25 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||||
|
let qk = QK8_0;
|
||||||
|
if n % QK8_0 != 0 {
|
||||||
|
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
||||||
|
}
|
||||||
|
unsafe {
|
||||||
|
let mut acc = _mm256_setzero_ps();
|
||||||
|
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||||
|
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
|
||||||
|
let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i);
|
||||||
|
let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
|
||||||
|
let q = mul_sum_i8_pairs_float(bx, by);
|
||||||
|
acc = _mm256_fmadd_ps(d, q, acc);
|
||||||
|
}
|
||||||
|
Ok(hsum_float_8(acc))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const K_SHUFFLE: [u8; 128] = [
|
const K_SHUFFLE: [u8; 128] = [
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
|
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
|
||||||
4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
|
4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
|
||||||
|
@ -421,7 +421,11 @@ impl GgmlType for BlockQ8_0 {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(unreachable_code)]
|
||||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||||
|
#[cfg(target_feature = "avx")]
|
||||||
|
return super::avx::vec_dot_q8_0_q8_0(n, xs, ys);
|
||||||
|
|
||||||
let qk = QK8_0;
|
let qk = QK8_0;
|
||||||
if n % QK8_0 != 0 {
|
if n % QK8_0 != 0 {
|
||||||
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
||||||
|
Reference in New Issue
Block a user