mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Handle the batch dimension in quantized MMV on metal. (#2022)
This commit is contained in:
@ -149,8 +149,11 @@ impl QMetalStorage {
|
|||||||
let (n, k) = self_shape.dims2()?;
|
let (n, k) = self_shape.dims2()?;
|
||||||
let mut dst_shape = src_shape.dims().to_vec();
|
let mut dst_shape = src_shape.dims().to_vec();
|
||||||
|
|
||||||
|
// We always use a single batch dimension and stack all the tensors in the batch on the
|
||||||
|
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
|
||||||
|
// properly.
|
||||||
let (b, m) = match dst_shape.len() {
|
let (b, m) = match dst_shape.len() {
|
||||||
3 => (dst_shape[0], dst_shape[1]),
|
3 => (1, dst_shape[0] * dst_shape[1]),
|
||||||
2 => (1, dst_shape[0]),
|
2 => (1, dst_shape[0]),
|
||||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||||
};
|
};
|
||||||
|
Reference in New Issue
Block a user