mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add vecdot for q6k-q8k. (#476)
* Add vecdot for q6k-q8k. * Add some testing for q8k. * Use QMatMul for the output layer.
This commit is contained in:
@ -462,8 +462,62 @@ impl GgmlType for BlockQ6K {
|
||||
const BLCK_SIZE: usize = QK_K;
|
||||
type VecDotType = BlockQ8K;
|
||||
|
||||
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
todo!()
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if n % QK_K != 0 {
|
||||
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
|
||||
}
|
||||
|
||||
let mut aux8 = [0i8; QK_K];
|
||||
let mut aux16 = [0i16; 8];
|
||||
let mut sums = [0f32; 8];
|
||||
let mut aux32 = [0f32; 8];
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
let q4 = &x.ql;
|
||||
let qh = &x.qh;
|
||||
let q8 = &y.qs;
|
||||
aux32.fill(0f32);
|
||||
|
||||
for j in (0..QK_K).step_by(128) {
|
||||
let aux8 = &mut aux8[j..];
|
||||
let q4 = &q4[j / 2..];
|
||||
let qh = &qh[j / 4..];
|
||||
for l in 0..32 {
|
||||
aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
|
||||
aux8[l + 32] =
|
||||
(((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
|
||||
aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
|
||||
aux8[l + 96] =
|
||||
(((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
|
||||
}
|
||||
}
|
||||
|
||||
for (j, &scale) in x.scales.iter().enumerate() {
|
||||
let scale = scale as f32;
|
||||
let q8 = &q8[16 * j..];
|
||||
let aux8 = &aux8[16 * j..];
|
||||
for l in 0..8 {
|
||||
aux16[l] = q8[l] as i16 * aux8[l] as i16;
|
||||
}
|
||||
for l in 0..8 {
|
||||
aux32[l] += scale * aux16[l] as f32
|
||||
}
|
||||
let q8 = &q8[8..];
|
||||
let aux8 = &aux8[8..];
|
||||
for l in 0..8 {
|
||||
aux16[l] = q8[l] as i16 * aux8[l] as i16;
|
||||
}
|
||||
for l in 0..8 {
|
||||
aux32[l] += scale * aux16[l] as f32
|
||||
}
|
||||
}
|
||||
|
||||
let d = x.d.to_f32() * y.d;
|
||||
for (sum, &a) in sums.iter_mut().zip(aux32.iter()) {
|
||||
*sum += a * d;
|
||||
}
|
||||
}
|
||||
Ok(sums.iter().sum())
|
||||
}
|
||||
|
||||
fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
|
||||
|
Reference in New Issue
Block a user