From df5f69444e438a7cd8d8ab4971579bf309b72114 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 10 Mar 2024 20:23:43 +0100 Subject: [PATCH] Properly handle the batch dimension in cuda quantized matmul. (#1832) --- candle-core/src/quantized/cuda.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 5b684573..c90cf576 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -313,7 +313,7 @@ impl QCudaStorage { } let data_f32 = self.dequantize(n * k)?; - let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0); + let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?; let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?; let mut out_shape = layout.shape().dims().to_vec(); out_shape.pop();