mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add a quantized variant of llama2.c (#1197)
* Add a quantized variant of llama2.c * Clippy fixes.
This commit is contained in:
@ -94,28 +94,18 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
let nb = n / QK8_0;
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
|
||||
}
|
||||
unsafe {
|
||||
let mut sumv0 = vdupq_n_f32(0.0f32);
|
||||
let mut sumv1 = vdupq_n_f32(0.0f32);
|
||||
for i in (0..nb).step_by(2) {
|
||||
for i in 0..nb {
|
||||
let x0 = &xs[i];
|
||||
let x1 = &xs[i + 1];
|
||||
let y0 = &ys[i];
|
||||
let y1 = &ys[i + 1];
|
||||
|
||||
let x0_0 = vld1q_s8(x0.qs.as_ptr());
|
||||
let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
|
||||
let x1_0 = vld1q_s8(x1.qs.as_ptr());
|
||||
let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));
|
||||
|
||||
// load y
|
||||
let y0_0 = vld1q_s8(y0.qs.as_ptr());
|
||||
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
let y1_0 = vld1q_s8(y1.qs.as_ptr());
|
||||
let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));
|
||||
|
||||
// TODO dotprod once this is the intrinsics are.
|
||||
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
|
||||
@ -123,28 +113,16 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
||||
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
|
||||
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
|
||||
|
||||
let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
|
||||
let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
|
||||
let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
|
||||
let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
|
||||
|
||||
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
|
||||
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
|
||||
let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
|
||||
let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
|
||||
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
vcvtq_f32_s32(vaddq_s32(p0, p1)),
|
||||
x0.d.to_f32() * y0.d.to_f32(),
|
||||
);
|
||||
sumv1 = vmlaq_n_f32(
|
||||
sumv1,
|
||||
vcvtq_f32_s32(vaddq_s32(p2, p3)),
|
||||
x1.d.to_f32() * y1.d.to_f32(),
|
||||
);
|
||||
}
|
||||
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
|
||||
Ok(vaddvq_f32(sumv0))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -61,10 +61,6 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
let nb = n / QK8_0;
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
|
||||
}
|
||||
unsafe {
|
||||
let mut acc = f32x4_splat(0.0f32);
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
|
Reference in New Issue
Block a user