Handle multiple dimensions in metal QMM + two fixes. (#2097)

This commit is contained in:
Laurent Mazare
2024-04-20 18:55:45 +02:00
committed by GitHub
parent 9215e9ce8c
commit dd78422701
2 changed files with 28 additions and 22 deletions

View File

@ -152,9 +152,9 @@ impl QMetalStorage {
// We always use a single batch dimension and stack all the tensors in the batch on the
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
// properly.
let (b, m) = match dst_shape.len() {
3 => (1, dst_shape[0] * dst_shape[1]),
2 => (1, dst_shape[0]),
let m = match dst_shape.len() {
3 => dst_shape[0] * dst_shape[1],
2 => dst_shape[0],
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
@ -166,18 +166,23 @@ impl QMetalStorage {
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let command_buffer = device.command_buffer()?;
candle_metal_kernels::call_quantized_matmul_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(b, m, n, k),
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&self.buffer,
&dst,
)
.map_err(MetalError::from)?;
// In some cases it would be better to use the mm variant, though it has its drawbacks
// around memory alignemnt.
for batch_id in 0..m {
candle_metal_kernels::call_quantized_matmul_mv_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(1, 1, n, k),
storage.buffer(),
(layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
&self.buffer,
batch_id * n * DType::F32.size_in_bytes(),
&dst,
)
.map_err(MetalError::from)?;
}
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
Ok((dst_storage, dst_shape))
}

View File

@ -1699,7 +1699,7 @@ pub enum GgmlDType {
}
#[allow(clippy::too_many_arguments)]
pub fn call_quantized_matmul_t(
pub fn call_quantized_matmul_mv_t(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
@ -1708,7 +1708,8 @@ pub fn call_quantized_matmul_t(
lhs: &Buffer,
lhs_offset: usize,
rhs: &Buffer,
output: &Buffer,
dst_offset: usize,
dst: &Buffer,
) -> Result<(), MetalKernelError> {
// Everything is in reverse
let ne00 = k as i64;
@ -1748,8 +1749,9 @@ pub fn call_quantized_matmul_t(
}
GgmlDType::Q2K => {
// Fixing a bug in Metal for GGML
let nth0 = 4;
let nth1 = 8;
// https://github.com/ggerganov/llama.cpp/blob/b8109bc0139f15a5b321909f47510b89dca47ffc/ggml-metal.m#L1576
let nth0 = 2;
let nth1 = 32;
let align = 4;
(nth0, nth1, align)
}
@ -1821,7 +1823,7 @@ pub fn call_quantized_matmul_t(
(
rhs,
(lhs, lhs_offset),
output,
(dst, dst_offset),
ne00,
ne01,
ne02,
@ -1840,10 +1842,9 @@ pub fn call_quantized_matmul_t(
r3
)
);
encoder.set_threadgroup_memory_length(0, 8192);
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
encoder.end_encoding();