Avoid some mutable variables (take 2). (#554)

* Avoid some mutable variables (take 2).

* Fix.
This commit is contained in:
Laurent Mazare
2023-08-22 18:51:20 +01:00
committed by GitHub
parent cc22d4db20
commit 07067b01dc
2 changed files with 29 additions and 37 deletions

View File

@ -4,7 +4,9 @@ pub(super) fn nearest_int(v: f32) -> i32 {
v.round() as i32
}
/// Validates that the input and output are the right size and returns an iterator which maps each input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed to be `T::BLCK_SIZE` long.
/// Validates that the input and output are the right size and returns an iterator which maps each
/// input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed
/// to be `T::BLCK_SIZE` long.
pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
xs: &'b [f32],
ys: &'a mut [T],
@ -23,7 +25,9 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
Ok(ys.iter_mut().zip(xs.chunks_exact(block_size)).collect())
}
/// Validates that the input and output are the right size and returns an iterator which maps each input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed to be `T::BLCK_SIZE` long.
/// Validates that the input and output are the right size and returns an iterator which maps each
/// input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed
/// to be `T::BLCK_SIZE` long.
pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
xs: &'a [T],
ys: &'b mut [f32],
@ -174,7 +178,7 @@ pub(super) unsafe fn make_qx_quants(
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = i32::max(-nmax, i32::min(nmax - 1, l));
let l = l.clamp(-nmax, nmax - 1);
let w = if weight_type == 1 { x * x } else { 1. };
let l = l as f32;
sumlx += w * x * l;
@ -198,7 +202,7 @@ pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32)
let n = x.len();
let mut l = vec![0; n];
// Get min/max
let mut min = *x
let min = *x
.iter()
.take(n)
.min_by(|a, b| a.total_cmp(b))
@ -211,9 +215,7 @@ pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32)
}
// Ensure min <= 0.0
if min > 0.0 {
min = 0.0;
}
let mut min = min.min(0.);
// Compute scale and inverse scale
let mut iscale = nmax as f32 / (max - min);
@ -225,8 +227,7 @@ pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32)
let mut did_change = false;
for (i, value) in x.iter().enumerate().take(n) {
let mut li = nearest_int(iscale * (value - min));
li = li.clamp(0, nmax);
let li = nearest_int(iscale * (value - min)).clamp(0, nmax);
let clamped_li = li as u8;
if clamped_li != l[i] {
l[i] = clamped_li;
@ -280,8 +281,8 @@ pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 {
let mut sumlx = 0.0;
let mut suml2 = 0.0;
for i in 0..n {
let mut li = (iscale * x[i]).round() as i32;
li = li.clamp(-nmax, nmax - 1);
let li = (iscale * x[i]).round() as i32;
let li = li.clamp(-nmax, nmax - 1);
l[i] = li as i8;
let w = x[i] * x[i];
sumlx += w * x[i] * li as f32;
@ -318,9 +319,8 @@ pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 {
return sumlx / suml2;
}
for i in 0..n {
let mut li = (iscale * x[i]).round() as i32;
li = li.clamp(-nmax, nmax - 1);
l[i] = (li + nmax) as i8;
let li = (iscale * x[i]).round() as i32;
l[i] = (li.clamp(-nmax, nmax - 1) + nmax) as i8;
}
1.0 / iscale
}