mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Handle the batch dimension in quantized MMV on metal. (#2022)
This commit is contained in:
@ -149,8 +149,11 @@ impl QMetalStorage {
|
||||
let (n, k) = self_shape.dims2()?;
|
||||
let mut dst_shape = src_shape.dims().to_vec();
|
||||
|
||||
// We always use a single batch dimension and stack all the tensors in the batch on the
|
||||
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
|
||||
// properly.
|
||||
let (b, m) = match dst_shape.len() {
|
||||
3 => (dst_shape[0], dst_shape[1]),
|
||||
3 => (1, dst_shape[0] * dst_shape[1]),
|
||||
2 => (1, dst_shape[0]),
|
||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||
};
|
||||
|
Reference in New Issue
Block a user