From cf27b9b6368d0af086e107d1ce890b2993825282 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 22 Aug 2023 15:44:26 +0100 Subject: [PATCH] Avoid some mut in quantized functions. (#550) * Avoid a couple more 'let mut'. * Tweaks. --- candle-core/src/quantized/k_quants.rs | 41 +++++++++++---------------- candle-core/src/quantized/utils.rs | 28 +++++++++--------- 2 files changed, 30 insertions(+), 39 deletions(-) diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index bfc471a3..3e45bc6d 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -503,8 +503,7 @@ impl GgmlType for BlockQ2K { } let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32; for ii in 0..16 { - let mut ll = nearest_int((x[16 * j + ii] + dm) / d); - ll = ll.clamp(0, 3); + let ll = nearest_int((x[16 * j + ii] + dm) / d).clamp(0, 3); big_l[16 * j + ii] = ll as u8; } } @@ -587,14 +586,14 @@ impl GgmlType for BlockQ3K { if max_scale != 0.0 { let iscale = -32.0 / max_scale; for (j, scale) in scales.iter().enumerate() { - let mut l_val = nearest_int(iscale * scale); - l_val = l_val.clamp(-32, 31) + 32; + let l_val = nearest_int(iscale * scale); + let l_val = l_val.clamp(-32, 31) + 32; if j < 8 { block.scales[j] = (l_val & 0xF) as u8; } else { block.scales[j - 8] |= ((l_val & 0xF) << 4) as u8; } - l_val >>= 4; + let l_val = l_val >> 4; block.scales[j % 4 + 8] |= (l_val << (2 * (j / 4))) as u8; } block.d = f16::from_f32(1.0 / iscale); @@ -614,9 +613,8 @@ impl GgmlType for BlockQ3K { let d = block.d.to_f32() * sc as f32; if d != 0.0 { for ii in 0..16 { - let mut l_val = nearest_int(x[16 * j + ii] / d); - l_val = l_val.clamp(-4, 3); - l[16 * j + ii] = (l_val + 4) as i8; + let l_val = nearest_int(x[16 * j + ii] / d); + l[16 * j + ii] = (l_val.clamp(-4, 3) + 4) as i8; } } } @@ -702,7 +700,7 @@ impl GgmlType for BlockQ3K { // 16 block finished => advance scale index is += 1; } - //32 block finished => increase shift 
and m + // 32 block finished => increase shift and m shift += 2; m <<= 1; } @@ -743,10 +741,8 @@ impl GgmlType for BlockQ4K { let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 }; for j in 0..QK_K / 32 { - let mut ls = nearest_int(inv_scale * scales[j]) as u8; - let mut lm = nearest_int(inv_min * mins[j]) as u8; - ls = std::cmp::min(63, ls); - lm = std::cmp::min(63, lm); + let ls = nearest_int(inv_scale * scales[j]).min(63) as u8; + let lm = nearest_int(inv_min * mins[j]).min(63) as u8; if j < 4 { block.scales[j] = ls; block.scales[j + 4] = lm; @@ -768,9 +764,8 @@ impl GgmlType for BlockQ4K { if d != 0.0 { let dm = block.dmin.to_f32() * m as f32; for ii in 0..32 { - let mut l_val = nearest_int((x[32 * j + ii] + dm) / d); - l_val = l_val.clamp(0, 15); - l[32 * j + ii] = l_val as u8; + let l_val = nearest_int((x[32 * j + ii] + dm) / d); + l[32 * j + ii] = l_val.clamp(0, 15) as u8; } } } @@ -791,10 +786,10 @@ impl GgmlType for BlockQ4K { let d = block.d.to_f32(); let min = block.dmin.to_f32(); let q = &block.qs; - let mut is = 0; let mut ys_index = 0; for j in (0..QK_K).step_by(64) { + let is = j / 32; let q = &q[j / 2..j / 2 + 32]; let (sc, m) = get_scale_min_k4(is, &block.scales); let d1 = d * sc as f32; @@ -810,7 +805,6 @@ impl GgmlType for BlockQ4K { y[ys_index] = d2 * (q >> 4) as f32 - m2; ys_index += 1; } - is += 2; } } Ok(()) @@ -848,10 +842,8 @@ impl GgmlType for BlockQ5K { }; let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 }; for j in 0..QK_K / 32 { - let mut ls = nearest_int(inv_scale * scales[j]) as u8; - let mut lm = nearest_int(inv_min * mins[j]) as u8; - ls = ls.min(63); - lm = lm.min(63); + let ls = nearest_int(inv_scale * scales[j]).min(63) as u8; + let lm = nearest_int(inv_min * mins[j]).min(63) as u8; if j < 4 { block.scales[j] = ls; block.scales[j + 4] = lm; @@ -873,9 +865,8 @@ impl GgmlType for BlockQ5K { } let dm = block.dmin.to_f32() * m as f32; for ii in 0..32 { - let mut ll = nearest_int((x[32 * j + ii] + dm) / d); - ll = 
ll.min(31).max(0); - l[32 * j + ii] = ll as u8; + let ll = nearest_int((x[32 * j + ii] + dm) / d); + l[32 * j + ii] = ll.clamp(0, 31) as u8; } } diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/utils.rs index fded9d61..edbffa35 100644 --- a/candle-core/src/quantized/utils.rs +++ b/candle-core/src/quantized/utils.rs @@ -4,7 +4,9 @@ pub(super) fn nearest_int(v: f32) -> i32 { v.round() as i32 } -/// Validates that the input and output are the right size and returns an iterator which maps each input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed to be `T::BLCK_SIZE` long. +/// Validates that the input and output are the right size and returns an iterator which maps each +/// input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed +/// to be `T::BLCK_SIZE` long. pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>( xs: &'b [f32], ys: &'a mut [T], @@ -23,7 +25,9 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>( Ok(ys.iter_mut().zip(xs.chunks_exact(block_size)).collect()) } -/// Validates that the input and output are the right size and returns an iterator which maps each input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed to be `T::BLCK_SIZE` long. +/// Validates that the input and output are the right size and returns an iterator which maps each +/// input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed +/// to be `T::BLCK_SIZE` long. pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>( xs: &'a [T], ys: &'b mut [f32], @@ -174,7 +178,7 @@ pub(super) unsafe fn make_qx_quants( for i in 0..n { let x = *x.add(i); let l = nearest_int(iscale * x); - let l = i32::max(-nmax, i32::min(nmax - 1, l)); + let l = l.clamp(-nmax, nmax - 1); let w = if weight_type == 1 { x * x } else { 1. 
}; let l = l as f32; sumlx += w * x * l; @@ -198,7 +202,7 @@ pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) let n = x.len(); let mut l = vec![0; n]; // Get min/max - let mut min = *x + let min = *x .iter() .take(n) .min_by(|a, b| a.total_cmp(b)) @@ -211,9 +215,7 @@ pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) } // Ensure min <= 0.0 - if min > 0.0 { - min = 0.0; - } + let mut min = min.min(0.); // Compute scale and inverse scale let mut iscale = nmax as f32 / (max - min); @@ -225,8 +227,7 @@ pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) let mut did_change = false; for (i, value) in x.iter().enumerate().take(n) { - let mut li = nearest_int(iscale * (value - min)); - li = li.clamp(0, nmax); + let li = nearest_int(iscale * (value - min)).clamp(0, nmax); let clamped_li = li as u8; if clamped_li != l[i] { l[i] = clamped_li; @@ -280,8 +281,8 @@ pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 { let mut sumlx = 0.0; let mut suml2 = 0.0; for i in 0..n { - let mut li = (iscale * x[i]).round() as i32; - li = li.clamp(-nmax, nmax - 1); + let li = (iscale * x[i]).round() as i32; + let li = li.clamp(-nmax, nmax - 1); l[i] = li as i8; let w = x[i] * x[i]; sumlx += w * x[i] * li as f32; @@ -318,9 +319,8 @@ pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 { return sumlx / suml2; } for i in 0..n { - let mut li = (iscale * x[i]).round() as i32; - li = li.clamp(-nmax, nmax - 1); - l[i] = (li + nmax) as i8; + let li = (iscale * x[i]).round() as i32; + l[i] = (li.clamp(-nmax, nmax - 1) + nmax) as i8; } 1.0 / iscale }