From 9fd52b3b71b37049d4f824379ce74b3f4d1b4eeb Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 6 Apr 2024 20:02:24 +0200 Subject: [PATCH] Handle the batch dimension in quantized MMV on metal. (#2022) --- candle-core/src/quantized/metal.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index 7be0f74e..c310d766 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -149,8 +149,11 @@ impl QMetalStorage { let (n, k) = self_shape.dims2()?; let mut dst_shape = src_shape.dims().to_vec(); + // We always use a single batch dimension and stack all the tensors in the batch on the + // second dimension as the implementation in candle-metal-kernels doesn't handle batch + // properly. let (b, m) = match dst_shape.len() { - 3 => (dst_shape[0], dst_shape[1]), + 3 => (1, dst_shape[0] * dst_shape[1]), 2 => (1, dst_shape[0]), n => crate::bail!("Invalid rank {n} for quantized matmul metal"), };