Handle the batch dimension in quantized MMV on metal. (#2022)

This commit is contained in:
Laurent Mazare
2024-04-06 20:02:24 +02:00
committed by GitHub
parent e662431acf
commit 9fd52b3b71

View File

@ -149,8 +149,11 @@ impl QMetalStorage {
let (n, k) = self_shape.dims2()?; let (n, k) = self_shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec(); let mut dst_shape = src_shape.dims().to_vec();
// We always use a single batch dimension and stack all the tensors in the batch on the
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
// properly.
let (b, m) = match dst_shape.len() { let (b, m) = match dst_shape.len() {
3 => (dst_shape[0], dst_shape[1]), 3 => (1, dst_shape[0] * dst_shape[1]),
2 => (1, dst_shape[0]), 2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"), n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
}; };