diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index eeb9a975..edc5209b 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -6,8 +6,13 @@ use std::collections::HashMap; use std::ffi::c_void; use std::sync::RwLock; +pub mod mlx_gemm; +pub mod sort; pub mod utils; pub use utils::BufferOffset; + +pub use mlx_gemm::{call_mlx_gemm, GemmDType}; +pub use sort::{call_arg_sort, call_mlx_arg_sort}; use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider}; const AFFINE: &str = include_str!("affine.metal"); @@ -17,6 +22,7 @@ const CONV: &str = include_str!("conv.metal"); const FILL: &str = include_str!("fill.metal"); const INDEXING: &str = include_str!("indexing.metal"); const MLX_GEMM: &str = include_str!("mlx_gemm.metal"); +const MLX_SORT: &str = include_str!("mlx_sort.metal"); const QUANTIZED: &str = include_str!("quantized.metal"); const RANDOM: &str = include_str!("random.metal"); const REDUCE: &str = include_str!("reduce.metal"); @@ -25,6 +31,29 @@ const TERNARY: &str = include_str!("ternary.metal"); const UNARY: &str = include_str!("unary.metal"); const SDPA: &str = include_str!("scaled_dot_product_attention.metal"); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum DType { + BF16, + F16, + F32, + I64, + U32, + U8, +} + +impl DType { + fn size_in_bytes(&self) -> usize { + match self { + Self::U8 => 1, + Self::U32 => 4, + Self::I64 => 8, + Self::BF16 => 2, + Self::F16 => 2, + Self::F32 => 4, + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { Affine, @@ -34,6 +63,7 @@ pub enum Source { Fill, Gemm, Indexing, + MlxSort, Quantized, Random, Reduce, @@ -257,6 +287,7 @@ impl Kernels { Source::Fill => FILL, Source::Gemm => MLX_GEMM, Source::Indexing => INDEXING, + Source::MlxSort => MLX_SORT, Source::Quantized => QUANTIZED, Source::Random => RANDOM, Source::Reduce => REDUCE, @@ -2516,219 +2547,6 @@ pub fn call_conv_transpose2d( Ok(()) } -#[allow(clippy::too_many_arguments)] -pub fn call_arg_sort( - device: &Device, - ep: impl EncoderProvider, - kernels: &Kernels, - name: &'static str, - nrows: usize, - ncols: usize, - ncols_pad: usize, - src: BufferOffset, - dst: &Buffer, -) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Sort, name)?; - let encoder = ep.encoder(); - let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); - encoder.set_compute_pipeline_state(&pipeline); - - set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64)); - - let thread_group_count = MTLSize { - width: 1, - height: nrows as u64, - depth: 1, - }; - let thread_group_size = MTLSize { - width: ncols_pad as u64, - height: 1, - depth: 1, - }; - - encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read); - encoder.use_resource(dst, metal::MTLResourceUsage::Write); - encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - Ok(()) -} - -#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] -pub enum GemmDType { - BF16, - F16, - F32, -} - -#[allow(clippy::too_many_arguments)] -pub fn call_mlx_gemm( - device: &Device, - ep: impl EncoderProvider, - kernels: &Kernels, - dtype: GemmDType, - (b, m, n, k): (usize, usize, usize, usize), - lhs_stride: &[usize], - lhs_offset: usize, - lhs_buffer: &Buffer, - rhs_stride: &[usize], - rhs_offset: usize, - rhs_buffer: &Buffer, - output: &Buffer, -) -> Result<(), MetalKernelError> { - #[derive(Debug)] - #[repr(C)] - struct GemmParams 
{ - m: i32, - n: i32, - k: i32, - lda: i32, - ldb: i32, - ldd: i32, - tiles_n: i32, - tiles_m: i32, - batch_stride_a: isize, - batch_stride_b: isize, - batch_stride_d: isize, - swizzle_log: i32, - gemm_k_iterations_aligned: i32, - batch_ndim: i32, - } - assert!(rhs_stride.len() >= 2); - assert!(lhs_stride.len() >= 2); - let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; - let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; - let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; - let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - // lhs has shape b, m, k - // We also allow for the case where the stride on the minor dimension is not as expected but - // there is a single element. - let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { - (k as i32, false) - } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { - (m as i32, true) - } else { - return Err(MetalKernelError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })?; - }; - // rhs has shape b, k, n - let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { - (n as i32, false) - } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { - (k as i32, true) - } else { - return Err(MetalKernelError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })?; - }; - let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2); - // https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422 - let constants = Some(ConstantValues::new(vec![ - (10, Value::Bool(/* has_batch */ b > 1)), - (100, Value::Bool(/* use_out_source */ false)), - (110, Value::Bool(/* do_axpby */ false)), - (200, Value::Bool(/* align_m */ m % bm == 0)), - (201, Value::Bool(/* align_n */ n % bn == 0)), - (202, Value::Bool(/* align_k */ k % bk == 0)), - (300, Value::Bool(/* do_gather */ false)), - ])); - - let swizzle_log = 0; - let tile = 1 << swizzle_log; - let tn = n.div_ceil(bn); - let tm = m.div_ceil(bm); - let tn = tn * tile; - let tm = tm.div_ceil(tile); - - let batch_stride_a = if lhs_stride.len() > 2 { - lhs_stride[lhs_stride.len() - 3] - } else { - m * k - }; - let batch_stride_b = if rhs_stride.len() > 2 { - rhs_stride[rhs_stride.len() - 3] - } else { - n * k - }; - - let gemm_params = GemmParams { - m: m as i32, - n: n as i32, - k: k as i32, - lda, - ldb, - ldd: n as i32, - tiles_n: tn as i32, - tiles_m: tm as i32, - swizzle_log, - batch_stride_a: batch_stride_a as isize, - batch_stride_b: batch_stride_b as isize, - batch_stride_d: (m * n) as isize, - batch_ndim: 1i32, - gemm_k_iterations_aligned: (k / bk) as i32, - }; - let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b]; - - // TODO(laurent): generate the name - // template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] - let name = match (dtype, a_trans, b_trans) { - (GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2", - (GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2", - (GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2", - (GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2", - (GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2", - (GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2", - (GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2", - (GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2", - (GemmDType::F16, false, false) => 
"gemm_nn_f16_f16_32_32_16_2_2", - (GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2", - (GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2", - (GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2", - }; - let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?; - let encoder = ep.encoder(); - let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); - encoder.set_compute_pipeline_state(&pipeline); - encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); - encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); - encoder.set_buffer(3, Some(output), 0); - encoder.set_bytes( - 4, - std::mem::size_of::() as u64, - &gemm_params as *const GemmParams as *const c_void, - ); - encoder.set_bytes( - 6, // batch_shape - std::mem::size_of::() as u64, - &(b as i32) as *const i32 as *const c_void, - ); - encoder.set_bytes( - 7, - (std::mem::size_of::() * batch_strides.len()) as u64, - batch_strides.as_ptr() as *const c_void, - ); - - let grid_size = MTLSize { - width: tn as u64, - height: tm as u64, - depth: /* batch_size_out */ b as u64, - }; - let group_size = MTLSize { - width: 32, - height: wn, - depth: wm, - }; - encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); - encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); - encoder.use_resource(output, metal::MTLResourceUsage::Write); - encoder.dispatch_thread_groups(grid_size, group_size); - Ok(()) -} - pub fn call_const_fill( device: &Device, ep: impl EncoderProvider, diff --git a/candle-metal-kernels/src/mlx_gemm.rs b/candle-metal-kernels/src/mlx_gemm.rs new file mode 100644 index 00000000..ee4292c3 --- /dev/null +++ b/candle-metal-kernels/src/mlx_gemm.rs @@ -0,0 +1,180 @@ +use crate::utils::EncoderProvider; +use crate::{ConstantValues, Kernels, MetalKernelError, Source, Value}; +use metal::{Buffer, ComputeCommandEncoderRef, Device, MTLSize, NSUInteger}; +use std::ffi::c_void; + +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +pub enum GemmDType { + BF16, + F16, + F32, +} + +#[allow(clippy::too_many_arguments)] +pub fn call_mlx_gemm( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + dtype: GemmDType, + (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + lhs_offset: usize, + lhs_buffer: &Buffer, + rhs_stride: &[usize], + rhs_offset: usize, + rhs_buffer: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + #[derive(Debug)] + #[repr(C)] + struct GemmParams { + m: i32, + n: i32, + k: i32, + lda: i32, + ldb: i32, + ldd: i32, + tiles_n: i32, + tiles_m: i32, + batch_stride_a: isize, + batch_stride_b: isize, + batch_stride_d: isize, + swizzle_log: i32, + gemm_k_iterations_aligned: i32, + batch_ndim: i32, + } + assert!(rhs_stride.len() >= 2); + assert!(lhs_stride.len() >= 2); + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + // lhs has shape b, m, k + // We also allow for the case where the stride on the minor dimension is not as expected but + // there is a single element. 
+ let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, false) + } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { + (m as i32, true) + } else { + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; + }; + // rhs has shape b, k, n + let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, false) + } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { + (k as i32, true) + } else { + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; + }; + let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2); + // https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422 + let constants = Some(ConstantValues::new(vec![ + (10, Value::Bool(/* has_batch */ b > 1)), + (100, Value::Bool(/* use_out_source */ false)), + (110, Value::Bool(/* do_axpby */ false)), + (200, Value::Bool(/* align_m */ m % bm == 0)), + (201, Value::Bool(/* align_n */ n % bn == 0)), + (202, Value::Bool(/* align_k */ k % bk == 0)), + (300, Value::Bool(/* do_gather */ false)), + ])); + + let swizzle_log = 0; + let tile = 1 << swizzle_log; + let tn = n.div_ceil(bn); + let tm = m.div_ceil(bm); + let tn = tn * tile; + let tm = tm.div_ceil(tile); + + let batch_stride_a = if lhs_stride.len() > 2 { + lhs_stride[lhs_stride.len() - 3] + } else { + m * k + }; + let batch_stride_b = if rhs_stride.len() > 2 { + rhs_stride[rhs_stride.len() - 3] + } else { + n * k + }; + + let gemm_params = GemmParams { + m: m as i32, + n: n as i32, + k: k as i32, + lda, + ldb, + ldd: n as i32, + tiles_n: tn as i32, + tiles_m: tm as i32, + swizzle_log, + batch_stride_a: batch_stride_a as isize, + batch_stride_b: batch_stride_b as isize, + batch_stride_d: (m * n) as isize, + batch_ndim: 1i32, + gemm_k_iterations_aligned: (k / bk) as i32, + }; + let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b]; + + // TODO(laurent): generate the name + // template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] + let name = match (dtype, a_trans, b_trans) { + (GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2", + (GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2", + (GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2", + (GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2", + (GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2", + (GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2", + (GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2", + (GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2", + (GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2", + (GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2", + (GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2", + (GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2", + }; + let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); + encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); + encoder.set_buffer(3, Some(output), 0); + 
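// Buffer indices follow the MLX steel gemm kernel layout: 0 = lhs, 1 = rhs, 3 = output, 4 = GemmParams, 6 = batch_shape, 7 = batch_strides. Indices 2 and 5 are skipped; they are presumably only bound when the use_out_source / do_axpby function constants (set to false above) are enabled. +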
encoder.set_bytes( + 4, + std::mem::size_of::<GemmParams>() as u64, + &gemm_params as *const GemmParams as *const c_void, + ); + encoder.set_bytes( + 6, // batch_shape + std::mem::size_of::<i32>() as u64, + &(b as i32) as *const i32 as *const c_void, + ); + encoder.set_bytes( + 7, + (std::mem::size_of::<isize>() * batch_strides.len()) as u64, + batch_strides.as_ptr() as *const c_void, + ); + + let grid_size = MTLSize { + width: tn as u64, + height: tm as u64, + depth: /* batch_size_out */ b as u64, + }; + let group_size = MTLSize { + width: 32, + height: wn, + depth: wm, + }; + encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(grid_size, group_size); + Ok(()) +} diff --git a/candle-metal-kernels/src/mlx_sort.metal b/candle-metal-kernels/src/mlx_sort.metal new file mode 100644 index 00000000..31947545 --- /dev/null +++ b/candle-metal-kernels/src/mlx_sort.metal @@ -0,0 +1,856 @@ +// The implementation below comes from MLX. +// https://github.com/ml-explore/mlx/blob/0cea88bcc5e98e81a24d92eed8870a6976999f05/mlx/backend/metal/kernels/sort.h +// Copyright © 2023-2024 Apple Inc. + +#define MLX_MTL_CONST static constant constexpr const +#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)") + +#include <metal_stdlib> +using namespace metal; +typedef bfloat bfloat16_t; + +// From utils.h +/////////////////////////////////////////////////////////////////////////////// +// Type limits utils +/////////////////////////////////////////////////////////////////////////////// + +template <typename U> +struct Limits { + static const constant U max = metal::numeric_limits<U>::max(); + static const constant U min = metal::numeric_limits<U>::min(); + static const constant U finite_max = metal::numeric_limits<U>::max(); + static const constant U finite_min = metal::numeric_limits<U>::min(); +}; + +#define instantiate_default_limit(type) \ + template <> \ + struct Limits<type> { \ + static constexpr constant type max = metal::numeric_limits<type>::max(); \ + static constexpr constant type min = metal::numeric_limits<type>::min(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits<type>::max(); \ + static constexpr constant type finite_min = \ + metal::numeric_limits<type>::min(); \ + }; + +instantiate_default_limit(uint8_t); +instantiate_default_limit(uint16_t); +instantiate_default_limit(uint32_t); +instantiate_default_limit(uint64_t); +instantiate_default_limit(int8_t); +instantiate_default_limit(int16_t); +instantiate_default_limit(int32_t); +instantiate_default_limit(int64_t); + +#define instantiate_float_limit(type) \ + template <> \ + struct Limits<type> { \ + static constexpr constant type max = \ + metal::numeric_limits<type>::infinity(); \ + static constexpr constant type min = \ + -metal::numeric_limits<type>::infinity(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits<type>::max(); \ + static constexpr constant type finite_min = \ + -metal::numeric_limits<type>::max(); \ + }; + +instantiate_float_limit(half); +instantiate_float_limit(float); +instantiate_float_limit(bfloat16_t); + +template <> +struct Limits<bool> { + static constexpr constant bool max = true; + static constexpr constant bool min = false; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Single Array with generic dims + +template <typename IdxT = int64_t> +METAL_FUNC IdxT elem_to_loc( + IdxT elem, + constant const int* shape, + constant const int64_t* strides, + int ndim) { + IdxT loc = 0; + for (int i = ndim - 1; i
>= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +// Non templated version to handle arbitrary dims +template <typename IdxT = int64_t> +METAL_FUNC IdxT elem_to_loc( + uint3 elem, + constant const int* shape, + constant const int64_t* strides, + int ndim) { + IdxT loc = + elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]); + for (int d = ndim - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * IdxT(strides[d]); + elem.z /= shape[d]; + } + return loc; +} + + +// Instantiate a templated kernel. +// Extra args are used as template parameters: +// e.g. instantiate_kernel(binary_int, binary, a, b) -> +// [[host_name(binary_int)]] [kernel] binary<a, b> +#define instantiate_kernel(name, func, ...) \ + template [[host_name( \ + name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; + +// Based on GPU merge sort algorithm at +// https://github.com/NVIDIA/cccl/tree/main/cub/cub + +/////////////////////////////////////////////////////////////////////////////// +// Thread-level sort +/////////////////////////////////////////////////////////////////////////////// + +template <typename T> +METAL_FUNC void thread_swap(thread T& a, thread T& b) { + T w = a; + a = b; + b = w; +} + +template <typename T> +struct LessThan { + static constexpr constant T init = Limits<T>::max; + + METAL_FUNC bool operator()(T a, T b) { + return a < b; + } +}; + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short N_PER_THREAD, + typename CompareOp> +struct ThreadSort { + static METAL_FUNC void sort( + thread val_t (&vals)[N_PER_THREAD], + thread idx_t (&idxs)[N_PER_THREAD]) { + CompareOp op; + + MLX_MTL_LOOP_UNROLL + for (short i = 0; i < N_PER_THREAD; ++i) { + MLX_MTL_LOOP_UNROLL + for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { + if (op(vals[j + 1], vals[j])) { + thread_swap(vals[j + 1], vals[j]); + thread_swap(idxs[j + 1], idxs[j]); + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Threadgroup-level sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp> +struct BlockMergeSort { + using thread_sort_t = + ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>; + static METAL_FUNC int merge_partition( + const threadgroup val_t* As, + const threadgroup val_t* Bs, + short A_sz, + short B_sz, + short sort_md) { + CompareOp op; + + short A_st = max(0, sort_md - B_sz); + short A_ed = min(sort_md, A_sz); + + while (A_st < A_ed) { + short md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if (op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + } + + static METAL_FUNC void merge_step( + const threadgroup val_t* As, + const threadgroup val_t* Bs, + const threadgroup idx_t* As_idx, + const threadgroup idx_t* Bs_idx, + short A_sz, + short B_sz, + thread val_t (&vals)[N_PER_THREAD], + thread idx_t (&idxs)[N_PER_THREAD]) { + CompareOp op; + short a_idx = 0; + short b_idx = 0; + + for (int i = 0; i < N_PER_THREAD; ++i) { + auto a = As[a_idx]; + auto b = Bs[b_idx]; + bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); + + vals[i] = pred ? b : a; + idxs[i] = pred ?
Bs_idx[b_idx] : As_idx[a_idx]; + + b_idx += short(pred); + a_idx += short(!pred); + } + } + + static METAL_FUNC void sort( + threadgroup val_t* tgp_vals [[threadgroup(0)]], + threadgroup idx_t* tgp_idxs [[threadgroup(1)]], + int size_sorted_axis, + uint3 lid [[thread_position_in_threadgroup]]) { + // Get thread location + int idx = lid.x * N_PER_THREAD; + + // Load from shared memory + thread val_t thread_vals[N_PER_THREAD]; + thread idx_t thread_idxs[N_PER_THREAD]; + for (int i = 0; i < N_PER_THREAD; ++i) { + thread_vals[i] = tgp_vals[idx + i]; + if (ARG_SORT) { + thread_idxs[i] = tgp_idxs[idx + i]; + } + } + + // Per thread sort + if (idx < size_sorted_axis) { + thread_sort_t::sort(thread_vals, thread_idxs); + } + + // Do merges using threadgroup memory + for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; + merge_threads *= 2) { + // Update threadgroup memory + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if (ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Find location in merge step + int merge_group = lid.x / merge_threads; + int merge_lane = lid.x % merge_threads; + + int sort_sz = N_PER_THREAD * merge_threads; + int sort_st = N_PER_THREAD * merge_threads * merge_group; + + // As = tgp_vals[A_st:A_ed] is sorted + // Bs = tgp_vals[B_st:B_ed] is sorted + int A_st = sort_st; + int A_ed = sort_st + sort_sz / 2; + int B_st = sort_st + sort_sz / 2; + int B_ed = sort_st + sort_sz; + + const threadgroup val_t* As = tgp_vals + A_st; + const threadgroup val_t* Bs = tgp_vals + B_st; + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + // Find a partition of merge elements + // Ci = merge(As[partition:], Bs[sort_md - partition:]) + // of size N_PER_THREAD for each merge lane i + // C = [Ci] is sorted + int sort_md = N_PER_THREAD * merge_lane; + int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md); + + As += partition; + Bs += sort_md - partition; + + A_sz -= partition; + B_sz -= sort_md - partition; + + const threadgroup idx_t* As_idx = + ARG_SORT ? tgp_idxs + A_st + partition : nullptr; + const threadgroup idx_t* Bs_idx = + ARG_SORT ? 
tgp_idxs + B_st + sort_md - partition : nullptr; + + // Merge starting at the partition and store results in thread registers + merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs); + } + + // Write out to shared memory + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if (ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Kernel sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan<T>> +struct KernelMergeSort { + using val_t = T; + using idx_t = uint; + using block_merge_sort_t = BlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + static METAL_FUNC void block_sort( + const device T* inp, + device U* out, + const constant int& size_sorted_axis, + const constant int& in_stride_sorted_axis, + const constant int& out_stride_sorted_axis, + const constant int& in_stride_segment_axis, + const constant int& out_stride_segment_axis, + threadgroup val_t* tgp_vals, + threadgroup idx_t* tgp_idxs, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // tid.y tells us the segment index + inp += tid.y * in_stride_segment_axis; + out += tid.y * out_stride_segment_axis; + + // Copy into threadgroup memory + for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis] + : val_t(CompareOp::init); + if (ARG_SORT) { + tgp_idxs[i] = i; + } + } + + // Sort elements within the block + threadgroup_barrier(mem_flags::mem_threadgroup); + + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write output + for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) { + if (ARG_SORT) { + out[i * out_stride_sorted_axis] = tgp_idxs[i]; + } else { + out[i * out_stride_sorted_axis] = tgp_vals[i]; + } + } + } +}; + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort( + const device T* inp [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& in_stride_sorted_axis [[buffer(3)]], + const constant int& out_stride_sorted_axis [[buffer(4)]], + const constant int& in_stride_segment_axis [[buffer(5)]], + const constant int& out_stride_segment_axis [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = + KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; + using val_t = typename sort_kernel::val_t; + using idx_t = typename sort_kernel::idx_t; + + if (ARG_SORT) { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + tgp_vals, + tgp_idxs, + tid, + lid); + } else { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, +
out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + tgp_vals, + nullptr, + tid, + lid); + } +} + +constant constexpr const int zero_helper = 0; + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc( + const device T* inp [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& in_stride_sorted_axis [[buffer(3)]], + const constant int& out_stride_sorted_axis [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const constant int* nc_shape [[buffer(6)]], + const constant int64_t* in_nc_strides [[buffer(7)]], + const constant int64_t* out_nc_strides [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = + KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; + using val_t = typename sort_kernel::val_t; + using idx_t = typename sort_kernel::idx_t; + + auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim); + auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim); + inp += in_block_idx; + out += out_block_idx; + + if (ARG_SORT) { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + zero_helper, + zero_helper, + tgp_vals, + tgp_idxs, + tid, + lid); + } else { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + zero_helper, + zero_helper, + tgp_vals, + nullptr, + tid, + lid); + } +} + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan<val_t>> +struct KernelMultiBlockMergeSort { + using block_merge_sort_t = BlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + static METAL_FUNC void block_sort( + const device val_t* inp, + device val_t* out_vals, + device idx_t* out_idxs, + const constant int& size_sorted_axis, + const constant int& stride_sorted_axis, + threadgroup val_t* tgp_vals, + threadgroup idx_t* tgp_idxs, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // tid.y tells us the segment index + int base_idx = tid.x * N_PER_BLOCK; + + // Copy into threadgroup memory + for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + tgp_vals[i] = idx < size_sorted_axis ?
inp[idx * stride_sorted_axis] + : val_t(CompareOp::init); + tgp_idxs[i] = idx; + } + + // Sort elements within the block + threadgroup_barrier(mem_flags::mem_threadgroup); + + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write output + for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + if (idx < size_sorted_axis) { + out_vals[idx] = tgp_vals[i]; + out_idxs[idx] = tgp_idxs[i]; + } + } + } + + static METAL_FUNC int merge_partition( + const device val_t* As, + const device val_t* Bs, + int A_sz, + int B_sz, + int sort_md) { + CompareOp op; + + int A_st = max(0, sort_md - B_sz); + int A_ed = min(sort_md, A_sz); + + while (A_st < A_ed) { + int md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if (op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + } +}; + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort( + const device val_t* inp [[buffer(0)]], + device val_t* out_vals [[buffer(1)]], + device idx_t* out_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& stride_sorted_axis [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const constant int* nc_shape [[buffer(6)]], + const constant int64_t* nc_strides [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + + auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); + inp += block_idx; + out_vals += tid.y * size_sorted_axis; + out_idxs += tid.y * size_sorted_axis; + + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + + sort_kernel::block_sort( + inp, + out_vals, + out_idxs, + size_sorted_axis, + stride_sorted_axis, + tgp_vals, + tgp_idxs, + tid, + lid); +} + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel]] void mb_block_partition( + device idx_t* block_partitions [[buffer(0)]], + const device val_t* dev_vals [[buffer(1)]], + const device idx_t* dev_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& merge_tiles [[buffer(4)]], + const constant int& n_blocks [[buffer(5)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgp_dims [[threads_per_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + + block_partitions += tid.y * tgp_dims.x; + dev_vals += tid.y * size_sorted_axis; + dev_idxs += tid.y * size_sorted_axis; + + for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) { + // Find location in merge step + int merge_group = i / merge_tiles; + int merge_lane = i % merge_tiles; + + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + + int A_st = min(size_sorted_axis, sort_st); + int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); + int B_st = A_ed; + int B_ed = min(size_sorted_axis, B_st + sort_sz / 2); + + int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane); + int partition = 
sort_kernel::merge_partition( + dev_vals + A_st, + dev_vals + B_st, + A_ed - A_st, + B_ed - B_st, + partition_at); + + block_partitions[i] = A_st + partition; + } +} + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan<val_t>> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +mb_block_merge( + const device idx_t* block_partitions [[buffer(0)]], + const device val_t* dev_vals_in [[buffer(1)]], + const device idx_t* dev_idxs_in [[buffer(2)]], + device val_t* dev_vals_out [[buffer(3)]], + device idx_t* dev_idxs_out [[buffer(4)]], + const constant int& size_sorted_axis [[buffer(5)]], + const constant int& merge_tiles [[buffer(6)]], + const constant int& num_tiles [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + using block_sort_t = typename sort_kernel::block_merge_sort_t; + + block_partitions += tid.y * (num_tiles + 1); + dev_vals_in += tid.y * size_sorted_axis; + dev_idxs_in += tid.y * size_sorted_axis; + dev_vals_out += tid.y * size_sorted_axis; + dev_idxs_out += tid.y * size_sorted_axis; + + int block_idx = tid.x; + int merge_group = block_idx / merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st; + + int A_st = block_partitions[block_idx + 0]; + int A_ed = block_partitions[block_idx + 1]; + int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st); + int B_ed = min( + size_sorted_axis, + 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed); + + if ((block_idx % merge_tiles) == merge_tiles - 1) { + A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); + B_ed = min(size_sorted_axis, sort_st + sort_sz); + } + + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + // Load from global memory + thread val_t thread_vals[N_PER_THREAD]; + thread idx_t thread_idxs[N_PER_THREAD]; + for (int i = 0; i < N_PER_THREAD; i++) { + int idx = BLOCK_THREADS * i + lid.x; + if (idx < (A_sz + B_sz)) { + thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] + : dev_vals_in[B_st + idx - A_sz]; + thread_idxs[i] = (idx < A_sz) ?
dev_idxs_in[A_st + idx] + : dev_idxs_in[B_st + idx - A_sz]; + } else { + thread_vals[i] = CompareOp::init; + thread_idxs[i] = 0; + } + } + + // Write to shared memory + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; i++) { + int idx = BLOCK_THREADS * i + lid.x; + tgp_vals[idx] = thread_vals[i]; + tgp_idxs[idx] = thread_idxs[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Merge + int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x)); + + int A_st_local = block_sort_t::merge_partition( + tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local); + int A_ed_local = A_sz; + + int B_st_local = sort_md_local - A_st_local; + int B_ed_local = B_sz; + + int A_sz_local = A_ed_local - A_st_local; + int B_sz_local = B_ed_local - B_st_local; + + // Do merge + block_sort_t::merge_step( + tgp_vals + A_st_local, + tgp_vals + A_ed_local + B_st_local, + tgp_idxs + A_st_local, + tgp_idxs + A_ed_local + B_st_local, + A_sz_local, + B_sz_local, + thread_vals, + thread_idxs); + + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + int idx = lid.x * N_PER_THREAD; + tgp_vals[idx + i] = thread_vals[i]; + tgp_idxs[idx + i] = thread_idxs[i]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Write output + int base_idx = tid.x * sort_kernel::N_PER_BLOCK; + for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + if (idx < size_sorted_axis) { + dev_vals_out[idx] = tgp_vals[i]; + dev_idxs_out[idx] = tgp_idxs[i]; + } + } +} + +#define instantiate_block_sort( \ + name, itname, itype, otname, otype, arg_sort, bn, tn) \ + instantiate_kernel("c" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \ + block_sort, itype, otype, arg_sort, bn, tn) \ + instantiate_kernel("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \ + block_sort_nc, itype, otype, arg_sort, bn, tn) + +#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \ + instantiate_block_sort( \ + arg_block_sort, itname, itype, uint32, uint32_t, true, bn, tn) + +#define instantiate_block_sort_base(itname, itype, bn, tn) \ + instantiate_block_sort( \ + _block_sort, itname, itype, itname, itype, false, bn, tn) + +#define instantiate_block_sort_tn(itname, itype, bn) \ + instantiate_block_sort_base(itname, itype, bn, 8) \ + instantiate_arg_block_sort_base(itname, itype, bn, 8) + +#define instantiate_block_sort_bn(itname, itype) \ + instantiate_block_sort_tn(itname, itype, 128) \ + instantiate_block_sort_tn(itname, itype, 256) \ + instantiate_block_sort_tn(itname, itype, 512) + +instantiate_block_sort_bn(uint8, uint8_t) +instantiate_block_sort_bn(uint32, uint32_t) +instantiate_block_sort_bn(float16, half) +instantiate_block_sort_bn(float32, float) +instantiate_block_sort_bn(bfloat16, bfloat16_t) + +#define instantiate_block_sort_long(itname, itype) \ + instantiate_block_sort_tn(itname, itype, 128) \ + instantiate_block_sort_tn(itname, itype, 256) + +instantiate_block_sort_long(int64, int64_t) + +#define instantiate_multi_block_sort( \ + vtname, vtype, itname, itype, arg_sort, bn, tn) \ + instantiate_kernel("sort_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \ + mb_block_sort, vtype, itype, arg_sort, bn, tn) \ + instantiate_kernel("partition_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \ + mb_block_partition, vtype, itype, arg_sort, bn, tn) \ + instantiate_kernel("merge_mbsort_" #vtname 
"_" #itname "_bn" #bn "_tn" #tn, \ + mb_block_merge, vtype, itype, arg_sort, bn, tn) + +#define instantiate_multi_block_sort_base(vtname, vtype) \ + instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8) + +instantiate_multi_block_sort_base(uint8, uint8_t) +instantiate_multi_block_sort_base(uint32, uint32_t) +instantiate_multi_block_sort_base(float16, half) +instantiate_multi_block_sort_base(float32, float) +instantiate_multi_block_sort_base(bfloat16, bfloat16_t) + +#define instantiate_multi_block_sort_long(vtname, vtype) \ + instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8) + +instantiate_multi_block_sort_long(int64, int64_t) // clang-format on diff --git a/candle-metal-kernels/src/sort.rs b/candle-metal-kernels/src/sort.rs new file mode 100644 index 00000000..e4140eb3 --- /dev/null +++ b/candle-metal-kernels/src/sort.rs @@ -0,0 +1,296 @@ +use crate::utils::{BufferOffset, EncoderProvider}; +use crate::{set_params, DType, Kernels, MetalKernelError, Source}; +use metal::{Buffer, ComputeCommandEncoderRef, Device, MTLResourceOptions, MTLSize}; + +#[allow(clippy::too_many_arguments)] +pub fn call_arg_sort( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + nrows: usize, + ncols: usize, + ncols_pad: usize, + src: BufferOffset, + dst: &Buffer, +) -> Result<(), crate::MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Sort, name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64)); + + let thread_group_count = MTLSize { + width: 1, + height: nrows as u64, + depth: 1, + }; + let thread_group_size = MTLSize { + width: ncols_pad as u64, + height: 1, + depth: 1, + }; + + encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(dst, metal::MTLResourceUsage::Write); + encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +fn mlx_dtype_str(dtype: DType) -> &'static str { + match dtype { + DType::U8 => "uint8", + DType::U32 => "uint32", + DType::I64 => "int64", + DType::F16 => "float16", + DType::BF16 => "bfloat16", + DType::F32 => "float32", + } +} + +#[allow(clippy::too_many_arguments)] +pub fn multi_block_sort( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + dtype: DType, + bn: usize, + tn: usize, + nblocks: usize, + nrows: usize, + ncols: usize, + src: BufferOffset, + dst: &Buffer, +) -> Result<(), MetalKernelError> { + let dtype_str = mlx_dtype_str(dtype); + // Do allocations + let el_count = nrows * ncols; + let bytes_len = (el_count * dtype.size_in_bytes()) as u64; + let mut dev_vals_0 = device.new_buffer(bytes_len, MTLResourceOptions::StorageModePrivate); + let mut dev_vals_1 = device.new_buffer(bytes_len, MTLResourceOptions::StorageModePrivate); + let mut dev_idxs_0 = + device.new_buffer(el_count as u64 * 4, MTLResourceOptions::StorageModePrivate); + let mut dev_idxs_1 = + device.new_buffer(el_count as u64 * 4, MTLResourceOptions::StorageModePrivate); + let mut block_partitions = device.new_buffer( + (nrows * (nblocks + 1)) as u64 * 4, + MTLResourceOptions::StorageModePrivate, + ); + // Prepare command encoder + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + // Do blockwise sort + { + let name = 
format!("sort_mbsort_{dtype_str}_uint32_bn{bn}_tn{tn}"); + let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?; + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + &src, + &mut dev_vals_0, + &mut dev_idxs_0, + /* size_sorted_axis */ ncols as i32, + /* stride_sorted_axis */ 1i32, + /* nc_dim */ 1i32, + /* nc_shape */ nrows as i32, + /* nc_str */ ncols as i32 + ) + ); + let thread_group_count = MTLSize { + width: nblocks as u64, + height: nrows as u64, + depth: 1, + }; + let thread_group_size = MTLSize { + width: bn as u64, + height: 1, + depth: 1, + }; + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + } + // Do merges + let mut ping = false; + let mut merge_tiles = 2; + let n_thr_per_group = usize::min(nblocks + 1, 1024); + let partition_name = format!("partition_mbsort_{dtype_str}_uint32_bn{bn}_tn{tn}"); + let merge_name = format!("merge_mbsort_float32_uint32_bn{bn}_tn{tn}"); + while merge_tiles / 2 < nblocks { + let (dev_vals_in, dev_vals_out) = if ping { + (&mut dev_vals_1, &mut dev_vals_0) + } else { + (&mut dev_vals_0, &mut dev_vals_1) + }; + let (dev_idxs_in, dev_idxs_out) = if ping { + (&mut dev_idxs_1, &mut dev_idxs_0) + } else { + (&mut dev_idxs_0, &mut dev_idxs_1) + }; + ping = !ping; + // Do partition + { + let pipeline = + kernels.load_pipeline(device, Source::MlxSort, partition_name.clone())?; + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + &mut block_partitions, + &mut *dev_vals_in, + &mut *dev_idxs_in, + /* size_sorted_axis */ ncols as i32, + /* merge_tiles */ merge_tiles as i32, + /* n_blocks */ nblocks as i32 + ) + ); + let thread_group_count = MTLSize { + width: 1, + height: nrows as u64, + depth: 1, + }; + let thread_group_size = MTLSize { + width: n_thr_per_group as u64, + height: 1, + depth: 1, + }; + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + } + // Do merge + { + let pipeline = kernels.load_pipeline(device, Source::MlxSort, merge_name.clone())?; + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + &block_partitions, + &*dev_vals_in, + &*dev_idxs_in, + &*dev_vals_out, + &*dev_idxs_out, + /* size_sorted_axis */ ncols as i32, + /* merge_tiles */ merge_tiles as i32, + /* n_blocks */ nblocks as i32 + ) + ); + let thread_group_count = MTLSize { + width: nblocks as u64, + height: nrows as u64, + depth: 1, + }; + let thread_group_size = MTLSize { + width: bn as u64, + height: 1, + depth: 1, + }; + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + } + merge_tiles *= 2; + } + let dev_idxs_out = if ping { + &mut dev_idxs_1 + } else { + &mut dev_idxs_0 + }; + // Copy output with appropriate strides + let copy_kernel = match dtype { + DType::U8 => crate::copy2d::U8, + DType::U32 => crate::copy2d::U32, + DType::I64 => crate::copy2d::I64, + DType::BF16 => crate::copy2d::BFLOAT, + DType::F16 => crate::copy2d::HALF, + DType::F32 => crate::copy2d::FLOAT, + }; + crate::call_copy2d( + device, + encoder, + kernels, + copy_kernel, + dev_idxs_out, + dst, + /* d1 */ nrows, + /* d2 */ ncols, + /* src_s */ ncols, + /* dst_s */ ncols, + /* src_o_in_bytes */ 0, + /*dst_o_in_bytes */ 0, + )?; + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn block_sort( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + dtype: DType, + bn: usize, + tn: usize, + nrows: usize, + ncols: usize, + src: BufferOffset, + dst: &Buffer, +) -> Result<(), MetalKernelError> { + let dtype_str = mlx_dtype_str(dtype); 
+ let name = format!("carg_block_sort_{dtype_str}_uint32_bn{bn}_tn{tn}"); + let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + &src, + dst, + ncols as i32, + 1i32, + 1i32, + ncols as i32, + ncols as i32 + ) + ); + let thread_group_count = MTLSize { + width: 1, + height: nrows as u64, + depth: 1, + }; + let thread_group_size = MTLSize { + width: bn as u64, + height: 1, + depth: 1, + }; + encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(dst, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_mlx_arg_sort( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + dtype: DType, + nrows: usize, + ncols: usize, + src: BufferOffset, + dst: &Buffer, +) -> Result<(), MetalKernelError> { + let tn = 8; + let bn = match ncols.div_ceil(tn) { + 257.. if dtype.size_in_bytes() <= 4 => 512, + 129.. => 256, + 0..129 => 128, + }; + let n_per_block = bn * tn; + let n_blocks = ncols.div_ceil(n_per_block); + if n_blocks > 1 { + multi_block_sort( + device, ep, kernels, dtype, bn, tn, n_blocks, nrows, ncols, src, dst, + )? + } else { + block_sort(device, ep, kernels, dtype, bn, tn, nrows, ncols, src, dst)? + } + Ok(()) +} diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 99e711f1..546680d4 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -605,6 +605,69 @@ fn affine_strided() { assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]); } +fn run_mlx_sort(v: &[T], ncols: usize) -> Vec { + let nrows = v.len() / ncols; + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let input = new_buffer(&device, v); + let indexes = vec![0u32; v.len()]; + let output = new_buffer(&device, &indexes); + + call_mlx_arg_sort( + &device, + command_buffer, + &kernels, + DType::F32, + nrows, + ncols, + BufferOffset::zero_offset(&input), + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec(&output, v.len()) +} + +#[test] +fn mlx_sort() { + use rand::SeedableRng; + use rand_distr::Distribution; + + let input: Vec<_> = (0..8).map(|v| v as f32).collect(); + let result = run_mlx_sort(&input, 4); + assert_eq!(result, [0, 1, 2, 3, 0, 1, 2, 3]); + let input: Vec<_> = (0..8).rev().map(|v| v as f32).collect(); + let result = run_mlx_sort(&input, 4); + assert_eq!(result, [3, 2, 1, 0, 3, 2, 1, 0]); + let input: Vec<_> = (0..1000).rev().map(|v| v as f32).collect(); + let result = run_mlx_sort(&input, 200); + let out: Vec<_> = (0..200).rev().collect(); + assert_eq!(&result[..200], out); + assert_eq!(&result[200..400], out); + assert_eq!(&result[400..600], out); + assert_eq!(&result[600..800], out); + assert_eq!(&result[800..], out); + + // Multi-block test + let ncols = 16000; + let mut rng = rand::rngs::StdRng::seed_from_u64(299792458); + let normal = rand_distr::Normal::new(0.0, 1.0).unwrap(); + let input: Vec = (0..ncols * 16).map(|_| normal.sample(&mut rng)).collect(); + let result = run_mlx_sort(&input, ncols); + for start in 0..16 { + let slice = &input[start * ncols..(start + 1) * ncols]; + let result = &result[start * ncols..(start + 1) * ncols]; + 
let mut perm: Vec<usize> = (0..ncols).collect(); + perm.sort_by(|i1, i2| slice[*i1].total_cmp(&slice[*i2])); + let perm: Vec<_> = perm.into_iter().map(|v| v as u32).collect(); + assert_eq!(perm, result); + } +} + #[test] fn index_select() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];