mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Handle multiple dimensions in metal QMM + two fixes. (#2097)
This commit is contained in:
@ -152,9 +152,9 @@ impl QMetalStorage {
|
||||
// We always use a single batch dimension and stack all the tensors in the batch on the
|
||||
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
|
||||
// properly.
|
||||
let (b, m) = match dst_shape.len() {
|
||||
3 => (1, dst_shape[0] * dst_shape[1]),
|
||||
2 => (1, dst_shape[0]),
|
||||
let m = match dst_shape.len() {
|
||||
3 => dst_shape[0] * dst_shape[1],
|
||||
2 => dst_shape[0],
|
||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||
};
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
@ -166,18 +166,23 @@ impl QMetalStorage {
|
||||
let device = storage.device().clone();
|
||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
||||
let command_buffer = device.command_buffer()?;
|
||||
candle_metal_kernels::call_quantized_matmul_t(
|
||||
device.device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
self.dtype.into(),
|
||||
(b, m, n, k),
|
||||
storage.buffer(),
|
||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||
&self.buffer,
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
// In some cases it would be better to use the mm variant, though it has its drawbacks
|
||||
// around memory alignemnt.
|
||||
for batch_id in 0..m {
|
||||
candle_metal_kernels::call_quantized_matmul_mv_t(
|
||||
device.device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
self.dtype.into(),
|
||||
(1, 1, n, k),
|
||||
storage.buffer(),
|
||||
(layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
|
||||
&self.buffer,
|
||||
batch_id * n * DType::F32.size_in_bytes(),
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
|
||||
Ok((dst_storage, dst_shape))
|
||||
}
|
||||
|
Reference in New Issue
Block a user