Compare commits

...

3 Commits

Author SHA1 Message Date
ab12425bff Another tweak. 2024-09-26 10:14:53 +02:00
43a8cbe244 Tweaks. 2024-09-26 00:05:17 +02:00
46acac5a64 Cuda quantization padding fix. 2024-09-25 23:40:14 +02:00
2 changed files with 12 additions and 3 deletions

View File

@ -37,6 +37,12 @@ fn pad(p: usize, q: usize) -> usize {
ceil_div(p, q) * q
}
/// Returns the size to allocate for a buffer holding `p` elements: `p` plus one full
/// `MATRIX_ROW_PADDING` of slack.
///
/// The extra space is there so that the last row can be padded. A whole
/// `MATRIX_ROW_PADDING` is added, rather than rounding `p` up to a multiple of it,
/// because this function does not have enough information to work out how many
/// elements the last row actually needs.
fn pad_for_alloc(p: usize) -> usize {
    // Worst-case slack: one complete padding unit on top of the requested size.
    let slack = MATRIX_ROW_PADDING;
    slack + p
}
fn quantize_q8_1(
src: &CudaView<f32>,
dst: &mut CudaSlice<u8>,
@ -444,8 +450,11 @@ impl QCudaStorage {
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
qcpu_storage.quantize(&src)?;
let data = qcpu_storage.data()?;
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
self.data = data;
let mut dst = self.device.alloc_zeros::<u8>(pad_for_alloc(src_len)).w()?;
self.device
.htod_sync_copy_into(data.as_ref(), &mut dst.slice_mut(..src_len))
.w()?;
self.data = dst;
Ok(())
}

View File

@ -18,7 +18,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
let actual_blocks = ys.len();
// Validate that the input is the right size
if expected_blocks != actual_blocks {
if expected_blocks > actual_blocks {
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
}