From 5221146cfa507821db71f28e7e56977e108c3b46 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 25 Sep 2024 23:35:16 +0200 Subject: [PATCH] Cuda quantization padding fix. --- candle-core/src/quantized/cuda.rs | 7 +++++-- candle-core/src/quantized/utils.rs | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 8e4884b2..f59fce48 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -34,7 +34,10 @@ fn ceil_div(p: usize, q: usize) -> usize { } fn pad(p: usize, q: usize) -> usize { - ceil_div(p, q) * q + // Overallocate by q rather than just padding by q as this should pad the last row + // and we don't have enough information here to know how many elements to add :( + // ceil_div(p, q) * q + p + q } fn quantize_q8_1( @@ -439,7 +442,7 @@ impl QCudaStorage { } _ => crate::bail!("only f32 can be quantized"), }; - let src_len = src.len(); + let src_len = pad(src.len(), MATRIX_ROW_PADDING); let src = crate::Storage::Cpu(crate::CpuStorage::F32(src)); let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?; qcpu_storage.quantize(&src)?; diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/utils.rs index fa6eff51..1af21fd3 100644 --- a/candle-core/src/quantized/utils.rs +++ b/candle-core/src/quantized/utils.rs @@ -18,7 +18,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>( let actual_blocks = ys.len(); // Validate that the input is the right size - if expected_blocks != actual_blocks { + if actual_blocks < expected_blocks { crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!") }