mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Integrate the MLX gemm kernels (#2468)
* Include the MLX gemm kernels. * Clippy lints. * Export the gemm_f32 kernel. * Add the f16/bf16 variants. * Add the initial dispatch code. * More plugging of the mlx kernels. * Add a currently broken test. * Tweaks. * Bugfix + get the tests to pass. * Enable the gemm bf16 tests. * Add some randomized tests. * Update candle-metal-kernels/src/lib.rs Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> * More fixes. * More clippy fixes. --------- Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:
@ -11,33 +11,35 @@ pub use utils::BufferOffset;
|
||||
use utils::{get_block_dims, linear_split, EncoderProvider};
|
||||
|
||||
const AFFINE: &str = include_str!("affine.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
const UNARY: &str = include_str!("unary.metal");
|
||||
const BINARY: &str = include_str!("binary.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const SORT: &str = include_str!("sort.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const UNARY: &str = include_str!("unary.metal");
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum Source {
|
||||
Affine,
|
||||
Indexing,
|
||||
Unary,
|
||||
Binary,
|
||||
Ternary,
|
||||
Cast,
|
||||
Reduce,
|
||||
Mfa,
|
||||
Conv,
|
||||
Random,
|
||||
Gemm,
|
||||
Indexing,
|
||||
Mfa,
|
||||
Quantized,
|
||||
Random,
|
||||
Reduce,
|
||||
Sort,
|
||||
Ternary,
|
||||
Unary,
|
||||
}
|
||||
|
||||
pub mod copy2d {
|
||||
@ -191,16 +193,17 @@ impl Kernels {
|
||||
fn get_library_source(&self, source: Source) -> &'static str {
|
||||
match source {
|
||||
Source::Affine => AFFINE,
|
||||
Source::Unary => UNARY,
|
||||
Source::Binary => BINARY,
|
||||
Source::Ternary => TERNARY,
|
||||
Source::Indexing => INDEXING,
|
||||
Source::Cast => CAST,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Conv => CONV,
|
||||
Source::Random => RANDOM,
|
||||
Source::Gemm => MLX_GEMM,
|
||||
Source::Indexing => INDEXING,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Random => RANDOM,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Sort => SORT,
|
||||
Source::Ternary => TERNARY,
|
||||
Source::Unary => UNARY,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
@ -2178,5 +2181,181 @@ pub fn call_arg_sort(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum GemmDType {
|
||||
BF16,
|
||||
F16,
|
||||
F32,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_mlx_gemm(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
dtype: GemmDType,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
lhs_offset: usize,
|
||||
lhs_buffer: &Buffer,
|
||||
rhs_stride: &[usize],
|
||||
rhs_offset: usize,
|
||||
rhs_buffer: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
#[derive(Debug)]
|
||||
#[repr(C)]
|
||||
struct GemmParams {
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
lda: i32,
|
||||
ldb: i32,
|
||||
ldd: i32,
|
||||
tiles_n: i32,
|
||||
tiles_m: i32,
|
||||
batch_stride_a: isize,
|
||||
batch_stride_b: isize,
|
||||
batch_stride_d: isize,
|
||||
swizzle_log: i32,
|
||||
gemm_k_iterations_aligned: i32,
|
||||
batch_ndim: i32,
|
||||
}
|
||||
assert!(rhs_stride.len() >= 2);
|
||||
assert!(lhs_stride.len() >= 2);
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
// lhs has shape b, m, k
|
||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
||||
// there is a single element.
|
||||
let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||
(k as i32, false)
|
||||
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
||||
(m as i32, true)
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?;
|
||||
};
|
||||
// rhs has shape b, k, n
|
||||
let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||
(n as i32, false)
|
||||
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
||||
(k as i32, true)
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?;
|
||||
};
|
||||
let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);
|
||||
// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422
|
||||
let constants = Some(ConstantValues::new(vec![
|
||||
(10, Value::Bool(/* has_batch */ b > 1)),
|
||||
(100, Value::Bool(/* use_out_source */ false)),
|
||||
(110, Value::Bool(/* do_axpby */ false)),
|
||||
(200, Value::Bool(/* align_m */ m % bm == 0)),
|
||||
(201, Value::Bool(/* align_n */ n % bn == 0)),
|
||||
(202, Value::Bool(/* align_k */ k % bk == 0)),
|
||||
(300, Value::Bool(/* do_gather */ false)),
|
||||
]));
|
||||
|
||||
let swizzle_log = 0;
|
||||
let tile = 1 << swizzle_log;
|
||||
let tn = n.div_ceil(bn);
|
||||
let tm = m.div_ceil(bm);
|
||||
let tn = tn * tile;
|
||||
let tm = tm.div_ceil(tile);
|
||||
|
||||
let batch_stride_a = if lhs_stride.len() > 2 {
|
||||
lhs_stride[lhs_stride.len() - 3]
|
||||
} else {
|
||||
m * k
|
||||
};
|
||||
let batch_stride_b = if rhs_stride.len() > 2 {
|
||||
rhs_stride[rhs_stride.len() - 3]
|
||||
} else {
|
||||
n * k
|
||||
};
|
||||
|
||||
let gemm_params = GemmParams {
|
||||
m: m as i32,
|
||||
n: n as i32,
|
||||
k: k as i32,
|
||||
lda,
|
||||
ldb,
|
||||
ldd: n as i32,
|
||||
tiles_n: tn as i32,
|
||||
tiles_m: tm as i32,
|
||||
swizzle_log,
|
||||
batch_stride_a: batch_stride_a as isize,
|
||||
batch_stride_b: batch_stride_b as isize,
|
||||
batch_stride_d: (m * n) as isize,
|
||||
batch_ndim: 1i32,
|
||||
gemm_k_iterations_aligned: (k / bk) as i32,
|
||||
};
|
||||
let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b];
|
||||
|
||||
// TODO(laurent): generate the name
|
||||
// template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]]
|
||||
let name = match (dtype, a_trans, b_trans) {
|
||||
(GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
|
||||
};
|
||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
||||
encoder.set_buffer(3, Some(output), 0);
|
||||
encoder.set_bytes(
|
||||
4,
|
||||
std::mem::size_of::<GemmParams>() as u64,
|
||||
&gemm_params as *const GemmParams as *const c_void,
|
||||
);
|
||||
encoder.set_bytes(
|
||||
6, // batch_shape
|
||||
std::mem::size_of::<i32>() as u64,
|
||||
&(b as i32) as *const i32 as *const c_void,
|
||||
);
|
||||
encoder.set_bytes(
|
||||
7,
|
||||
(std::mem::size_of::<isize>() * batch_strides.len()) as u64,
|
||||
batch_strides.as_ptr() as *const c_void,
|
||||
);
|
||||
|
||||
let grid_size = MTLSize {
|
||||
width: tn as u64,
|
||||
height: tm as u64,
|
||||
depth: /* batch_size_out */ b as u64,
|
||||
};
|
||||
let group_size = MTLSize {
|
||||
width: 32,
|
||||
height: wn,
|
||||
depth: wm,
|
||||
};
|
||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
Reference in New Issue
Block a user