From 6559eae72cd55939f477f6d9ba45c228823c5b06 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 25 Aug 2023 18:21:37 +0100 Subject: [PATCH] Avoid some transmutes. (#607) --- candle-core/src/quantized/k_quants.rs | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 3aefa5df..cdea2434 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -4,6 +4,7 @@ use super::utils::{ }; use super::GgmlDType; use crate::Result; +use byteorder::{ByteOrder, LittleEndian}; use half::f16; use rayon::prelude::*; @@ -308,7 +309,7 @@ impl GgmlType for BlockQ5_0 { let nb = k / QK5_0; for i in 0..nb { let d = xs[i].d.to_f32(); - let qh: u32 = unsafe { std::mem::transmute_copy(&xs[i].qh) }; + let qh: u32 = LittleEndian::read_u32(&xs[i].qh); for j in 0..(QK5_0 / 2) { let xh_0 = (((qh >> j) << 4) & 0x10) as u8; @@ -349,7 +350,7 @@ impl GgmlType for BlockQ5_1 { for i in 0..nb { let d = xs[i].d.to_f32(); let m = xs[i].m.to_f32(); - let qh: u32 = unsafe { std::mem::transmute_copy(&xs[i].qh) }; + let qh: u32 = LittleEndian::read_u32(&xs[i].qh); for j in 0..(QK5_1 / 2) { let xh_0 = (((qh >> j) << 4) & 0x10) as u8; @@ -719,10 +720,7 @@ impl GgmlType for BlockQ3K { a = &mut aux8[..]; - let aux_raw = unsafe { - std::mem::transmute::<&mut [u8; 12], &mut [u32; 3]>(&mut x.scales.clone()) - }; - auxs[0..3].copy_from_slice(aux_raw); + LittleEndian::read_u32_into(&x.scales, &mut auxs[0..3]); let tmp = auxs[2]; auxs[2] = ((auxs[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4); @@ -852,10 +850,7 @@ impl GgmlType for BlockQ3K { for (block, y) in group_for_dequantization(xs, ys)? { //Reconstruct the scales let mut aux = [0; 4]; - let aux_raw = unsafe { - std::mem::transmute::<&mut [u8; 12], &mut [u32; 3]>(&mut block.scales.clone()) - }; - aux[0..3].copy_from_slice(aux_raw); + LittleEndian::read_u32_into(&block.scales, &mut aux[0..3]); let tmp = aux[2]; aux[2] = ((aux[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4);