mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
neon optimized q8k multiplication. (#1021)
* neon optimized q8k multiplication. * Bugfixes. * simdification.
This commit is contained in:
@ -1756,14 +1756,18 @@ impl GgmlType for BlockQ8K {
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
#[allow(unreachable_code)]
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
#[cfg(target_feature = "neon")]
|
||||
return super::neon::vec_dot_q8k_q8k(n, xs, ys);
|
||||
|
||||
Self::vec_dot_unopt(n, xs, ys)
|
||||
}
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
||||
let qk = QK_K;
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
|
||||
}
|
||||
|
||||
// Generic implementation.
|
||||
|
@ -148,6 +148,35 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
let qk = QK_K;
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
|
||||
}
|
||||
|
||||
let mut sumf = 0f32;
|
||||
for (xs, ys) in xs.iter().zip(ys.iter()) {
|
||||
unsafe {
|
||||
let mut sum_i = vdupq_n_s32(0);
|
||||
let scale = xs.d * ys.d;
|
||||
let xs = xs.qs.as_ptr();
|
||||
let ys = ys.qs.as_ptr();
|
||||
for i in (0..QK_K).step_by(16) {
|
||||
let xs = vld1q_s8(xs.add(i));
|
||||
let ys = vld1q_s8(ys.add(i));
|
||||
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
|
||||
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
|
||||
|
||||
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
|
||||
sum_i = vaddq_s32(sum_i, xy)
|
||||
}
|
||||
sumf += vaddvq_s32(sum_i) as f32 * scale
|
||||
}
|
||||
}
|
||||
Ok(sumf)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
|
Reference in New Issue
Block a user