mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Add the MLX merge sort kernels (#2751)
* Add some metal sort kernels imported from MLX. * Add another test. * Start adding the multiblock version. * Proper kernel names. * Split out the main metal file. * Multi-block sort. * More sorting. * DType parametrization. * Add a larger test.
This commit is contained in:
@ -6,8 +6,13 @@ use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::sync::RwLock;
|
||||
|
||||
pub mod mlx_gemm;
|
||||
pub mod sort;
|
||||
pub mod utils;
|
||||
pub use utils::BufferOffset;
|
||||
|
||||
pub use mlx_gemm::{call_mlx_gemm, GemmDType};
|
||||
pub use sort::{call_arg_sort, call_mlx_arg_sort};
|
||||
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
|
||||
|
||||
const AFFINE: &str = include_str!("affine.metal");
|
||||
@ -17,6 +22,7 @@ const CONV: &str = include_str!("conv.metal");
|
||||
const FILL: &str = include_str!("fill.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
|
||||
const MLX_SORT: &str = include_str!("mlx_sort.metal");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
@ -25,6 +31,29 @@ const TERNARY: &str = include_str!("ternary.metal");
|
||||
const UNARY: &str = include_str!("unary.metal");
|
||||
const SDPA: &str = include_str!("scaled_dot_product_attention.metal");
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum DType {
|
||||
BF16,
|
||||
F16,
|
||||
F32,
|
||||
I64,
|
||||
U32,
|
||||
U8,
|
||||
}
|
||||
|
||||
impl DType {
|
||||
fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
Self::U8 => 1,
|
||||
Self::U32 => 4,
|
||||
Self::I64 => 8,
|
||||
Self::BF16 => 2,
|
||||
Self::F16 => 2,
|
||||
Self::F32 => 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum Source {
|
||||
Affine,
|
||||
@ -34,6 +63,7 @@ pub enum Source {
|
||||
Fill,
|
||||
Gemm,
|
||||
Indexing,
|
||||
MlxSort,
|
||||
Quantized,
|
||||
Random,
|
||||
Reduce,
|
||||
@ -257,6 +287,7 @@ impl Kernels {
|
||||
Source::Fill => FILL,
|
||||
Source::Gemm => MLX_GEMM,
|
||||
Source::Indexing => INDEXING,
|
||||
Source::MlxSort => MLX_SORT,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Random => RANDOM,
|
||||
Source::Reduce => REDUCE,
|
||||
@ -2516,219 +2547,6 @@ pub fn call_conv_transpose2d(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_arg_sort(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
ncols_pad: usize,
|
||||
src: BufferOffset,
|
||||
dst: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: nrows as u64,
|
||||
depth: 1,
|
||||
};
|
||||
let thread_group_size = MTLSize {
|
||||
width: ncols_pad as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
||||
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum GemmDType {
|
||||
BF16,
|
||||
F16,
|
||||
F32,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_mlx_gemm(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
dtype: GemmDType,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
lhs_offset: usize,
|
||||
lhs_buffer: &Buffer,
|
||||
rhs_stride: &[usize],
|
||||
rhs_offset: usize,
|
||||
rhs_buffer: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
#[derive(Debug)]
|
||||
#[repr(C)]
|
||||
struct GemmParams {
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
lda: i32,
|
||||
ldb: i32,
|
||||
ldd: i32,
|
||||
tiles_n: i32,
|
||||
tiles_m: i32,
|
||||
batch_stride_a: isize,
|
||||
batch_stride_b: isize,
|
||||
batch_stride_d: isize,
|
||||
swizzle_log: i32,
|
||||
gemm_k_iterations_aligned: i32,
|
||||
batch_ndim: i32,
|
||||
}
|
||||
assert!(rhs_stride.len() >= 2);
|
||||
assert!(lhs_stride.len() >= 2);
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
// lhs has shape b, m, k
|
||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
||||
// there is a single element.
|
||||
let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||
(k as i32, false)
|
||||
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
||||
(m as i32, true)
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?;
|
||||
};
|
||||
// rhs has shape b, k, n
|
||||
let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||
(n as i32, false)
|
||||
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
||||
(k as i32, true)
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?;
|
||||
};
|
||||
let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);
|
||||
// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422
|
||||
let constants = Some(ConstantValues::new(vec![
|
||||
(10, Value::Bool(/* has_batch */ b > 1)),
|
||||
(100, Value::Bool(/* use_out_source */ false)),
|
||||
(110, Value::Bool(/* do_axpby */ false)),
|
||||
(200, Value::Bool(/* align_m */ m % bm == 0)),
|
||||
(201, Value::Bool(/* align_n */ n % bn == 0)),
|
||||
(202, Value::Bool(/* align_k */ k % bk == 0)),
|
||||
(300, Value::Bool(/* do_gather */ false)),
|
||||
]));
|
||||
|
||||
let swizzle_log = 0;
|
||||
let tile = 1 << swizzle_log;
|
||||
let tn = n.div_ceil(bn);
|
||||
let tm = m.div_ceil(bm);
|
||||
let tn = tn * tile;
|
||||
let tm = tm.div_ceil(tile);
|
||||
|
||||
let batch_stride_a = if lhs_stride.len() > 2 {
|
||||
lhs_stride[lhs_stride.len() - 3]
|
||||
} else {
|
||||
m * k
|
||||
};
|
||||
let batch_stride_b = if rhs_stride.len() > 2 {
|
||||
rhs_stride[rhs_stride.len() - 3]
|
||||
} else {
|
||||
n * k
|
||||
};
|
||||
|
||||
let gemm_params = GemmParams {
|
||||
m: m as i32,
|
||||
n: n as i32,
|
||||
k: k as i32,
|
||||
lda,
|
||||
ldb,
|
||||
ldd: n as i32,
|
||||
tiles_n: tn as i32,
|
||||
tiles_m: tm as i32,
|
||||
swizzle_log,
|
||||
batch_stride_a: batch_stride_a as isize,
|
||||
batch_stride_b: batch_stride_b as isize,
|
||||
batch_stride_d: (m * n) as isize,
|
||||
batch_ndim: 1i32,
|
||||
gemm_k_iterations_aligned: (k / bk) as i32,
|
||||
};
|
||||
let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b];
|
||||
|
||||
// TODO(laurent): generate the name
|
||||
// template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]]
|
||||
let name = match (dtype, a_trans, b_trans) {
|
||||
(GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
|
||||
};
|
||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
||||
encoder.set_buffer(3, Some(output), 0);
|
||||
encoder.set_bytes(
|
||||
4,
|
||||
std::mem::size_of::<GemmParams>() as u64,
|
||||
&gemm_params as *const GemmParams as *const c_void,
|
||||
);
|
||||
encoder.set_bytes(
|
||||
6, // batch_shape
|
||||
std::mem::size_of::<i32>() as u64,
|
||||
&(b as i32) as *const i32 as *const c_void,
|
||||
);
|
||||
encoder.set_bytes(
|
||||
7,
|
||||
(std::mem::size_of::<isize>() * batch_strides.len()) as u64,
|
||||
batch_strides.as_ptr() as *const c_void,
|
||||
);
|
||||
|
||||
let grid_size = MTLSize {
|
||||
width: tn as u64,
|
||||
height: tm as u64,
|
||||
depth: /* batch_size_out */ b as u64,
|
||||
};
|
||||
let group_size = MTLSize {
|
||||
width: 32,
|
||||
height: wn,
|
||||
depth: wm,
|
||||
};
|
||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn call_const_fill(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
|
Reference in New Issue
Block a user