diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 5b684573..c90cf576 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -313,7 +313,7 @@ impl QCudaStorage { } let data_f32 = self.dequantize(n * k)?; - let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0); + let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?; let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?; let mut out_shape = layout.shape().dims().to_vec(); out_shape.pop();