Avoid some mutable variables (take 2). (#554)

* Avoid some mutable variables (take 2).

* Fix.
This commit is contained in:
Laurent Mazare
2023-08-22 18:51:20 +01:00
committed by GitHub
parent cc22d4db20
commit 07067b01dc
2 changed files with 29 additions and 37 deletions

View File

@ -503,8 +503,7 @@ impl GgmlType for BlockQ2K {
}
let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32;
for ii in 0..16 {
let mut ll = nearest_int((x[16 * j + ii] + dm) / d);
ll = ll.clamp(0, 3);
let ll = nearest_int((x[16 * j + ii] + dm) / d).clamp(0, 3);
big_l[16 * j + ii] = ll as u8;
}
}
@ -587,14 +586,14 @@ impl GgmlType for BlockQ3K {
if max_scale != 0.0 {
let iscale = -32.0 / max_scale;
for (j, scale) in scales.iter().enumerate() {
let mut l_val = nearest_int(iscale * scale);
l_val = l_val.clamp(-32, 31) + 32;
let l_val = nearest_int(iscale * scale);
let l_val = l_val.clamp(-32, 31) + 32;
if j < 8 {
block.scales[j] = (l_val & 0xF) as u8;
} else {
block.scales[j - 8] |= ((l_val & 0xF) << 4) as u8;
}
l_val >>= 4;
let l_val = l_val >> 4;
block.scales[j % 4 + 8] |= (l_val << (2 * (j / 4))) as u8;
}
block.d = f16::from_f32(1.0 / iscale);
@ -614,9 +613,8 @@ impl GgmlType for BlockQ3K {
let d = block.d.to_f32() * sc as f32;
if d != 0.0 {
for ii in 0..16 {
let mut l_val = nearest_int(x[16 * j + ii] / d);
l_val = l_val.clamp(-4, 3);
l[16 * j + ii] = (l_val + 4) as i8;
let l_val = nearest_int(x[16 * j + ii] / d);
l[16 * j + ii] = (l_val.clamp(-4, 3) + 4) as i8;
}
}
}
@ -702,7 +700,7 @@ impl GgmlType for BlockQ3K {
// 16 block finished => advance scale index
is += 1;
}
//32 block finished => increase shift and m
// 32 block finished => increase shift and m
shift += 2;
m <<= 1;
}
@ -743,10 +741,8 @@ impl GgmlType for BlockQ4K {
let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };
for j in 0..QK_K / 32 {
let mut ls = nearest_int(inv_scale * scales[j]) as u8;
let mut lm = nearest_int(inv_min * mins[j]) as u8;
ls = std::cmp::min(63, ls);
lm = std::cmp::min(63, lm);
let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
if j < 4 {
block.scales[j] = ls;
block.scales[j + 4] = lm;
@ -768,9 +764,8 @@ impl GgmlType for BlockQ4K {
if d != 0.0 {
let dm = block.dmin.to_f32() * m as f32;
for ii in 0..32 {
let mut l_val = nearest_int((x[32 * j + ii] + dm) / d);
l_val = l_val.clamp(0, 15);
l[32 * j + ii] = l_val as u8;
let l_val = nearest_int((x[32 * j + ii] + dm) / d);
l[32 * j + ii] = l_val.clamp(0, 15) as u8;
}
}
}
@ -848,10 +843,8 @@ impl GgmlType for BlockQ5K {
};
let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };
for j in 0..QK_K / 32 {
let mut ls = nearest_int(inv_scale * scales[j]) as u8;
let mut lm = nearest_int(inv_min * mins[j]) as u8;
ls = ls.min(63);
lm = lm.min(63);
let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
if j < 4 {
block.scales[j] = ls;
block.scales[j + 4] = lm;
@ -873,9 +866,8 @@ impl GgmlType for BlockQ5K {
}
let dm = block.dmin.to_f32() * m as f32;
for ii in 0..32 {
let mut ll = nearest_int((x[32 * j + ii] + dm) / d);
ll = ll.min(31).max(0);
l[32 * j + ii] = ll as u8;
let ll = nearest_int((x[32 * j + ii] + dm) / d);
l[32 * j + ii] = ll.clamp(0, 31) as u8;
}
}