mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the MLX merge sort kernels (#2751)
* Add some metal sort kernels imported from MLX. * Add another test. * Start adding the multiblock version. * Proper kernel names. * Split out the main metal file. * Multi-block sort. * More sorting. * DType parametrization. * Add a larger test.
This commit is contained in:
@ -6,8 +6,13 @@ use std::collections::HashMap;
|
|||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
|
||||||
|
pub mod mlx_gemm;
|
||||||
|
pub mod sort;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
pub use utils::BufferOffset;
|
pub use utils::BufferOffset;
|
||||||
|
|
||||||
|
pub use mlx_gemm::{call_mlx_gemm, GemmDType};
|
||||||
|
pub use sort::{call_arg_sort, call_mlx_arg_sort};
|
||||||
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
|
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
|
||||||
|
|
||||||
const AFFINE: &str = include_str!("affine.metal");
|
const AFFINE: &str = include_str!("affine.metal");
|
||||||
@ -17,6 +22,7 @@ const CONV: &str = include_str!("conv.metal");
|
|||||||
const FILL: &str = include_str!("fill.metal");
|
const FILL: &str = include_str!("fill.metal");
|
||||||
const INDEXING: &str = include_str!("indexing.metal");
|
const INDEXING: &str = include_str!("indexing.metal");
|
||||||
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
|
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
|
||||||
|
const MLX_SORT: &str = include_str!("mlx_sort.metal");
|
||||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||||
const RANDOM: &str = include_str!("random.metal");
|
const RANDOM: &str = include_str!("random.metal");
|
||||||
const REDUCE: &str = include_str!("reduce.metal");
|
const REDUCE: &str = include_str!("reduce.metal");
|
||||||
@ -25,6 +31,29 @@ const TERNARY: &str = include_str!("ternary.metal");
|
|||||||
const UNARY: &str = include_str!("unary.metal");
|
const UNARY: &str = include_str!("unary.metal");
|
||||||
const SDPA: &str = include_str!("scaled_dot_product_attention.metal");
|
const SDPA: &str = include_str!("scaled_dot_product_attention.metal");
|
||||||
|
|
||||||
|
/// Element data types that kernels in this crate can be parametrized over.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    BF16,
    F16,
    F32,
    I64,
    U32,
    U8,
}

impl DType {
    /// Returns the width of a single element of this type, in bytes.
    fn size_in_bytes(&self) -> usize {
        // Variants are grouped by width rather than listed one per arm.
        match *self {
            Self::U8 => 1,
            Self::BF16 | Self::F16 => 2,
            Self::U32 | Self::F32 => 4,
            Self::I64 => 8,
        }
    }
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub enum Source {
|
pub enum Source {
|
||||||
Affine,
|
Affine,
|
||||||
@ -34,6 +63,7 @@ pub enum Source {
|
|||||||
Fill,
|
Fill,
|
||||||
Gemm,
|
Gemm,
|
||||||
Indexing,
|
Indexing,
|
||||||
|
MlxSort,
|
||||||
Quantized,
|
Quantized,
|
||||||
Random,
|
Random,
|
||||||
Reduce,
|
Reduce,
|
||||||
@ -257,6 +287,7 @@ impl Kernels {
|
|||||||
Source::Fill => FILL,
|
Source::Fill => FILL,
|
||||||
Source::Gemm => MLX_GEMM,
|
Source::Gemm => MLX_GEMM,
|
||||||
Source::Indexing => INDEXING,
|
Source::Indexing => INDEXING,
|
||||||
|
Source::MlxSort => MLX_SORT,
|
||||||
Source::Quantized => QUANTIZED,
|
Source::Quantized => QUANTIZED,
|
||||||
Source::Random => RANDOM,
|
Source::Random => RANDOM,
|
||||||
Source::Reduce => REDUCE,
|
Source::Reduce => REDUCE,
|
||||||
@ -2516,219 +2547,6 @@ pub fn call_conv_transpose2d(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_arg_sort(
|
|
||||||
device: &Device,
|
|
||||||
ep: impl EncoderProvider,
|
|
||||||
kernels: &Kernels,
|
|
||||||
name: &'static str,
|
|
||||||
nrows: usize,
|
|
||||||
ncols: usize,
|
|
||||||
ncols_pad: usize,
|
|
||||||
src: BufferOffset,
|
|
||||||
dst: &Buffer,
|
|
||||||
) -> Result<(), MetalKernelError> {
|
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
|
|
||||||
let encoder = ep.encoder();
|
|
||||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
|
||||||
|
|
||||||
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
|
|
||||||
|
|
||||||
let thread_group_count = MTLSize {
|
|
||||||
width: 1,
|
|
||||||
height: nrows as u64,
|
|
||||||
depth: 1,
|
|
||||||
};
|
|
||||||
let thread_group_size = MTLSize {
|
|
||||||
width: ncols_pad as u64,
|
|
||||||
height: 1,
|
|
||||||
depth: 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
|
|
||||||
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
|
||||||
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Element types for which `call_mlx_gemm` can select a GEMM kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GemmDType {
    BF16,
    F16,
    F32,
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn call_mlx_gemm(
|
|
||||||
device: &Device,
|
|
||||||
ep: impl EncoderProvider,
|
|
||||||
kernels: &Kernels,
|
|
||||||
dtype: GemmDType,
|
|
||||||
(b, m, n, k): (usize, usize, usize, usize),
|
|
||||||
lhs_stride: &[usize],
|
|
||||||
lhs_offset: usize,
|
|
||||||
lhs_buffer: &Buffer,
|
|
||||||
rhs_stride: &[usize],
|
|
||||||
rhs_offset: usize,
|
|
||||||
rhs_buffer: &Buffer,
|
|
||||||
output: &Buffer,
|
|
||||||
) -> Result<(), MetalKernelError> {
|
|
||||||
#[derive(Debug)]
|
|
||||||
#[repr(C)]
|
|
||||||
struct GemmParams {
|
|
||||||
m: i32,
|
|
||||||
n: i32,
|
|
||||||
k: i32,
|
|
||||||
lda: i32,
|
|
||||||
ldb: i32,
|
|
||||||
ldd: i32,
|
|
||||||
tiles_n: i32,
|
|
||||||
tiles_m: i32,
|
|
||||||
batch_stride_a: isize,
|
|
||||||
batch_stride_b: isize,
|
|
||||||
batch_stride_d: isize,
|
|
||||||
swizzle_log: i32,
|
|
||||||
gemm_k_iterations_aligned: i32,
|
|
||||||
batch_ndim: i32,
|
|
||||||
}
|
|
||||||
assert!(rhs_stride.len() >= 2);
|
|
||||||
assert!(lhs_stride.len() >= 2);
|
|
||||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
|
||||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
|
||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
|
||||||
// lhs has shape b, m, k
|
|
||||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
|
||||||
// there is a single element.
|
|
||||||
let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
|
||||||
(k as i32, false)
|
|
||||||
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
|
||||||
(m as i32, true)
|
|
||||||
} else {
|
|
||||||
return Err(MetalKernelError::MatMulNonContiguous {
|
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
|
||||||
rhs_stride: rhs_stride.to_vec(),
|
|
||||||
mnk: (m, n, k),
|
|
||||||
})?;
|
|
||||||
};
|
|
||||||
// rhs has shape b, k, n
|
|
||||||
let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
|
||||||
(n as i32, false)
|
|
||||||
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
|
||||||
(k as i32, true)
|
|
||||||
} else {
|
|
||||||
return Err(MetalKernelError::MatMulNonContiguous {
|
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
|
||||||
rhs_stride: rhs_stride.to_vec(),
|
|
||||||
mnk: (m, n, k),
|
|
||||||
})?;
|
|
||||||
};
|
|
||||||
let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);
|
|
||||||
// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422
|
|
||||||
let constants = Some(ConstantValues::new(vec![
|
|
||||||
(10, Value::Bool(/* has_batch */ b > 1)),
|
|
||||||
(100, Value::Bool(/* use_out_source */ false)),
|
|
||||||
(110, Value::Bool(/* do_axpby */ false)),
|
|
||||||
(200, Value::Bool(/* align_m */ m % bm == 0)),
|
|
||||||
(201, Value::Bool(/* align_n */ n % bn == 0)),
|
|
||||||
(202, Value::Bool(/* align_k */ k % bk == 0)),
|
|
||||||
(300, Value::Bool(/* do_gather */ false)),
|
|
||||||
]));
|
|
||||||
|
|
||||||
let swizzle_log = 0;
|
|
||||||
let tile = 1 << swizzle_log;
|
|
||||||
let tn = n.div_ceil(bn);
|
|
||||||
let tm = m.div_ceil(bm);
|
|
||||||
let tn = tn * tile;
|
|
||||||
let tm = tm.div_ceil(tile);
|
|
||||||
|
|
||||||
let batch_stride_a = if lhs_stride.len() > 2 {
|
|
||||||
lhs_stride[lhs_stride.len() - 3]
|
|
||||||
} else {
|
|
||||||
m * k
|
|
||||||
};
|
|
||||||
let batch_stride_b = if rhs_stride.len() > 2 {
|
|
||||||
rhs_stride[rhs_stride.len() - 3]
|
|
||||||
} else {
|
|
||||||
n * k
|
|
||||||
};
|
|
||||||
|
|
||||||
let gemm_params = GemmParams {
|
|
||||||
m: m as i32,
|
|
||||||
n: n as i32,
|
|
||||||
k: k as i32,
|
|
||||||
lda,
|
|
||||||
ldb,
|
|
||||||
ldd: n as i32,
|
|
||||||
tiles_n: tn as i32,
|
|
||||||
tiles_m: tm as i32,
|
|
||||||
swizzle_log,
|
|
||||||
batch_stride_a: batch_stride_a as isize,
|
|
||||||
batch_stride_b: batch_stride_b as isize,
|
|
||||||
batch_stride_d: (m * n) as isize,
|
|
||||||
batch_ndim: 1i32,
|
|
||||||
gemm_k_iterations_aligned: (k / bk) as i32,
|
|
||||||
};
|
|
||||||
let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b];
|
|
||||||
|
|
||||||
// TODO(laurent): generate the name
|
|
||||||
// template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]]
|
|
||||||
let name = match (dtype, a_trans, b_trans) {
|
|
||||||
(GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2",
|
|
||||||
(GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2",
|
|
||||||
(GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2",
|
|
||||||
(GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2",
|
|
||||||
(GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2",
|
|
||||||
(GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2",
|
|
||||||
(GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2",
|
|
||||||
(GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2",
|
|
||||||
(GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2",
|
|
||||||
(GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
|
|
||||||
(GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
|
|
||||||
(GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
|
|
||||||
};
|
|
||||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
|
|
||||||
let encoder = ep.encoder();
|
|
||||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
|
||||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
|
||||||
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
|
||||||
encoder.set_buffer(3, Some(output), 0);
|
|
||||||
encoder.set_bytes(
|
|
||||||
4,
|
|
||||||
std::mem::size_of::<GemmParams>() as u64,
|
|
||||||
&gemm_params as *const GemmParams as *const c_void,
|
|
||||||
);
|
|
||||||
encoder.set_bytes(
|
|
||||||
6, // batch_shape
|
|
||||||
std::mem::size_of::<i32>() as u64,
|
|
||||||
&(b as i32) as *const i32 as *const c_void,
|
|
||||||
);
|
|
||||||
encoder.set_bytes(
|
|
||||||
7,
|
|
||||||
(std::mem::size_of::<isize>() * batch_strides.len()) as u64,
|
|
||||||
batch_strides.as_ptr() as *const c_void,
|
|
||||||
);
|
|
||||||
|
|
||||||
let grid_size = MTLSize {
|
|
||||||
width: tn as u64,
|
|
||||||
height: tm as u64,
|
|
||||||
depth: /* batch_size_out */ b as u64,
|
|
||||||
};
|
|
||||||
let group_size = MTLSize {
|
|
||||||
width: 32,
|
|
||||||
height: wn,
|
|
||||||
depth: wm,
|
|
||||||
};
|
|
||||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
|
||||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
||||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_const_fill(
|
pub fn call_const_fill(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
ep: impl EncoderProvider,
|
ep: impl EncoderProvider,
|
||||||
|
180
candle-metal-kernels/src/mlx_gemm.rs
Normal file
180
candle-metal-kernels/src/mlx_gemm.rs
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
use crate::utils::EncoderProvider;
|
||||||
|
use crate::{ConstantValues, Kernels, MetalKernelError, Source, Value};
|
||||||
|
use metal::{Buffer, ComputeCommandEncoderRef, Device, MTLSize, NSUInteger};
|
||||||
|
use std::ffi::c_void;
|
||||||
|
|
||||||
|
/// Element types for which `call_mlx_gemm` can select a GEMM kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GemmDType {
    BF16,
    F16,
    F32,
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn call_mlx_gemm(
|
||||||
|
device: &Device,
|
||||||
|
ep: impl EncoderProvider,
|
||||||
|
kernels: &Kernels,
|
||||||
|
dtype: GemmDType,
|
||||||
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
|
lhs_stride: &[usize],
|
||||||
|
lhs_offset: usize,
|
||||||
|
lhs_buffer: &Buffer,
|
||||||
|
rhs_stride: &[usize],
|
||||||
|
rhs_offset: usize,
|
||||||
|
rhs_buffer: &Buffer,
|
||||||
|
output: &Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[repr(C)]
|
||||||
|
struct GemmParams {
|
||||||
|
m: i32,
|
||||||
|
n: i32,
|
||||||
|
k: i32,
|
||||||
|
lda: i32,
|
||||||
|
ldb: i32,
|
||||||
|
ldd: i32,
|
||||||
|
tiles_n: i32,
|
||||||
|
tiles_m: i32,
|
||||||
|
batch_stride_a: isize,
|
||||||
|
batch_stride_b: isize,
|
||||||
|
batch_stride_d: isize,
|
||||||
|
swizzle_log: i32,
|
||||||
|
gemm_k_iterations_aligned: i32,
|
||||||
|
batch_ndim: i32,
|
||||||
|
}
|
||||||
|
assert!(rhs_stride.len() >= 2);
|
||||||
|
assert!(lhs_stride.len() >= 2);
|
||||||
|
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||||
|
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||||
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
|
// lhs has shape b, m, k
|
||||||
|
// We also allow for the case where the stride on the minor dimension is not as expected but
|
||||||
|
// there is a single element.
|
||||||
|
let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||||
|
(k as i32, false)
|
||||||
|
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
||||||
|
(m as i32, true)
|
||||||
|
} else {
|
||||||
|
return Err(MetalKernelError::MatMulNonContiguous {
|
||||||
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
|
mnk: (m, n, k),
|
||||||
|
})?;
|
||||||
|
};
|
||||||
|
// rhs has shape b, k, n
|
||||||
|
let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||||
|
(n as i32, false)
|
||||||
|
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
||||||
|
(k as i32, true)
|
||||||
|
} else {
|
||||||
|
return Err(MetalKernelError::MatMulNonContiguous {
|
||||||
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
|
mnk: (m, n, k),
|
||||||
|
})?;
|
||||||
|
};
|
||||||
|
let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);
|
||||||
|
// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422
|
||||||
|
let constants = Some(ConstantValues::new(vec![
|
||||||
|
(10, Value::Bool(/* has_batch */ b > 1)),
|
||||||
|
(100, Value::Bool(/* use_out_source */ false)),
|
||||||
|
(110, Value::Bool(/* do_axpby */ false)),
|
||||||
|
(200, Value::Bool(/* align_m */ m % bm == 0)),
|
||||||
|
(201, Value::Bool(/* align_n */ n % bn == 0)),
|
||||||
|
(202, Value::Bool(/* align_k */ k % bk == 0)),
|
||||||
|
(300, Value::Bool(/* do_gather */ false)),
|
||||||
|
]));
|
||||||
|
|
||||||
|
let swizzle_log = 0;
|
||||||
|
let tile = 1 << swizzle_log;
|
||||||
|
let tn = n.div_ceil(bn);
|
||||||
|
let tm = m.div_ceil(bm);
|
||||||
|
let tn = tn * tile;
|
||||||
|
let tm = tm.div_ceil(tile);
|
||||||
|
|
||||||
|
let batch_stride_a = if lhs_stride.len() > 2 {
|
||||||
|
lhs_stride[lhs_stride.len() - 3]
|
||||||
|
} else {
|
||||||
|
m * k
|
||||||
|
};
|
||||||
|
let batch_stride_b = if rhs_stride.len() > 2 {
|
||||||
|
rhs_stride[rhs_stride.len() - 3]
|
||||||
|
} else {
|
||||||
|
n * k
|
||||||
|
};
|
||||||
|
|
||||||
|
let gemm_params = GemmParams {
|
||||||
|
m: m as i32,
|
||||||
|
n: n as i32,
|
||||||
|
k: k as i32,
|
||||||
|
lda,
|
||||||
|
ldb,
|
||||||
|
ldd: n as i32,
|
||||||
|
tiles_n: tn as i32,
|
||||||
|
tiles_m: tm as i32,
|
||||||
|
swizzle_log,
|
||||||
|
batch_stride_a: batch_stride_a as isize,
|
||||||
|
batch_stride_b: batch_stride_b as isize,
|
||||||
|
batch_stride_d: (m * n) as isize,
|
||||||
|
batch_ndim: 1i32,
|
||||||
|
gemm_k_iterations_aligned: (k / bk) as i32,
|
||||||
|
};
|
||||||
|
let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b];
|
||||||
|
|
||||||
|
// TODO(laurent): generate the name
|
||||||
|
// template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]]
|
||||||
|
let name = match (dtype, a_trans, b_trans) {
|
||||||
|
(GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2",
|
||||||
|
(GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2",
|
||||||
|
(GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2",
|
||||||
|
(GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2",
|
||||||
|
(GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2",
|
||||||
|
(GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2",
|
||||||
|
(GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2",
|
||||||
|
(GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2",
|
||||||
|
(GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2",
|
||||||
|
(GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
|
||||||
|
(GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
|
||||||
|
(GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
|
||||||
|
};
|
||||||
|
let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
|
||||||
|
let encoder = ep.encoder();
|
||||||
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||||
|
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
||||||
|
encoder.set_buffer(3, Some(output), 0);
|
||||||
|
encoder.set_bytes(
|
||||||
|
4,
|
||||||
|
std::mem::size_of::<GemmParams>() as u64,
|
||||||
|
&gemm_params as *const GemmParams as *const c_void,
|
||||||
|
);
|
||||||
|
encoder.set_bytes(
|
||||||
|
6, // batch_shape
|
||||||
|
std::mem::size_of::<i32>() as u64,
|
||||||
|
&(b as i32) as *const i32 as *const c_void,
|
||||||
|
);
|
||||||
|
encoder.set_bytes(
|
||||||
|
7,
|
||||||
|
(std::mem::size_of::<isize>() * batch_strides.len()) as u64,
|
||||||
|
batch_strides.as_ptr() as *const c_void,
|
||||||
|
);
|
||||||
|
|
||||||
|
let grid_size = MTLSize {
|
||||||
|
width: tn as u64,
|
||||||
|
height: tm as u64,
|
||||||
|
depth: /* batch_size_out */ b as u64,
|
||||||
|
};
|
||||||
|
let group_size = MTLSize {
|
||||||
|
width: 32,
|
||||||
|
height: wn,
|
||||||
|
depth: wm,
|
||||||
|
};
|
||||||
|
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
|
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||||
|
Ok(())
|
||||||
|
}
|
856
candle-metal-kernels/src/mlx_sort.metal
Normal file
856
candle-metal-kernels/src/mlx_sort.metal
Normal file
@ -0,0 +1,856 @@
|
|||||||
|
// The implementation below comes from MLX.
|
||||||
|
// https://github.com/ml-explore/mlx/blob/0cea88bcc5e98e81a24d92eed8870a6976999f05/mlx/backend/metal/kernels/sort.h
|
||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#define MLX_MTL_CONST static constant constexpr const
|
||||||
|
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
using namespace metal;
|
||||||
|
typedef bfloat bfloat16_t;
|
||||||
|
|
||||||
|
// From utils.h
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Type limits utils
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Fallback for types without a dedicated specialization: every bound is taken
// directly from metal::numeric_limits<U>.
template <typename U>
struct Limits {
  static const constant U min = metal::numeric_limits<U>::min();
  static const constant U max = metal::numeric_limits<U>::max();
  static const constant U finite_min = metal::numeric_limits<U>::min();
  static const constant U finite_max = metal::numeric_limits<U>::max();
};

// Specialization generator for integer types: the "finite" bounds coincide
// with the regular numeric_limits bounds.
#define instantiate_default_limit(type)                                      \
  template <>                                                                \
  struct Limits<type> {                                                      \
    static constexpr constant type min = metal::numeric_limits<type>::min(); \
    static constexpr constant type max = metal::numeric_limits<type>::max(); \
    static constexpr constant type finite_min =                              \
        metal::numeric_limits<type>::min();                                  \
    static constexpr constant type finite_max =                              \
        metal::numeric_limits<type>::max();                                  \
  };

instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);

// Specialization generator for floating-point types: min/max are -inf/+inf,
// while finite_min/finite_max are the negated and plain numeric_limits max().
#define instantiate_float_limit(type)             \
  template <>                                     \
  struct Limits<type> {                           \
    static constexpr constant type min =          \
        -metal::numeric_limits<type>::infinity(); \
    static constexpr constant type max =          \
        metal::numeric_limits<type>::infinity();  \
    static constexpr constant type finite_min =   \
        -metal::numeric_limits<type>::max();      \
    static constexpr constant type finite_max =   \
        metal::numeric_limits<type>::max();       \
  };

instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);

// bool only defines min/max; there are no finite_* members.
template <>
struct Limits<bool> {
  static constexpr constant bool min = false;
  static constexpr constant bool max = true;
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Single Array with generic dims
|
||||||
|
|
||||||
|
// Maps a linear element index to a strided offset.
//
// `elem` is decomposed from the innermost dimension (ndim - 1) outwards using
// `shape`; each coordinate is scaled by the matching entry of `strides` and
// accumulated. The walk stops early once the remaining index reaches zero,
// since every further coordinate would be zero as well.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
    IdxT elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  IdxT offset = 0;
  int dim = ndim - 1;
  while (dim >= 0 && elem > 0) {
    const auto coord = elem % shape[dim];
    offset += coord * IdxT(strides[dim]);
    elem /= shape[dim];
    --dim;
  }
  return offset;
}
|
||||||
|
|
||||||
|
// Overload taking a 3-component index, for an arbitrary number of dimensions.
//
// elem.x and elem.y address the last and second-to-last dimensions directly;
// elem.z is a linear index over all remaining leading dimensions and is
// unravelled with `shape`, starting at dimension ndim - 3.
// NOTE(review): strides[ndim - 2] is read unconditionally, so this looks like
// it expects ndim >= 2 - confirm against callers.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  const IdxT inner = elem.x * IdxT(strides[ndim - 1]);
  const IdxT outer = elem.y * IdxT(strides[ndim - 2]);
  IdxT offset = inner + outer;

  int dim = ndim - 3;
  while (dim >= 0) {
    offset += (elem.z % shape[dim]) * IdxT(strides[dim]);
    elem.z /= shape[dim];
    --dim;
  }
  return offset;
}
|
||||||
|
|
||||||
|
|
||||||
|
// Explicitly instantiates a kernel template and gives it a host-visible name.
// Every argument after `func` is forwarded as a template parameter, e.g.
//   instantiate_kernel(binary_int, binary, a, b)
// expands to
//   template [[host_name(binary_int)]] [[kernel]]
//   decltype(binary<a, b>) binary<a, b>;
#define instantiate_kernel(name, func, ...)  \
  template [[host_name(name)]] [[kernel]]    \
  decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
|
||||||
|
|
||||||
|
// Based on GPU merge sort algorithm at
|
||||||
|
// https://github.com/NVIDIA/cccl/tree/main/cub/cub
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Thread-level sort
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Exchanges the contents of two values living in the thread address space.
template <typename T>
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
  const T saved = b;
  b = a;
  a = saved;
}
|
||||||
|
|
||||||
|
// Ascending-order comparator.
//
// `init` is Limits<T>::max, a value that no element compares greater than
// under operator<, which makes it usable as a filler.
template <typename T>
struct LessThan {
  static constexpr constant T init = Limits<T>::max;

  // True when `a` must be ordered strictly before `b`.
  METAL_FUNC bool operator()(T a, T b) {
    const bool a_goes_first = a < b;
    return a_goes_first;
  }
};
|
||||||
|
|
||||||
|
// Sorts the N_PER_THREAD values owned by a single thread, in place.
//
// The routine is an odd-even transposition sort: N_PER_THREAD passes, where
// pass p compares the neighbouring pairs starting at index (p & 1). Whenever
// two values are exchanged the matching entries of `idxs` are exchanged too,
// so both arrays stay aligned. Both loops are requested to be fully unrolled.
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short N_PER_THREAD,
    typename CompareOp>
struct ThreadSort {
  static METAL_FUNC void sort(
      thread val_t (&vals)[N_PER_THREAD],
      thread idx_t (&idxs)[N_PER_THREAD]) {
    CompareOp op;

    MLX_MTL_LOOP_UNROLL
    for (short pass = 0; pass < N_PER_THREAD; ++pass) {
      MLX_MTL_LOOP_UNROLL
      for (short lo = pass & 1; lo < N_PER_THREAD - 1; lo += 2) {
        // Swap the pair when the right neighbour should come first.
        if (op(vals[lo + 1], vals[lo])) {
          thread_swap(vals[lo + 1], vals[lo]);
          thread_swap(idxs[lo + 1], idxs[lo]);
        }
      }
    }
  }
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Threadgroup-level sort
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Sorts BLOCK_THREADS * N_PER_THREAD elements held in threadgroup memory.
//
// Each thread first sorts its own N_PER_THREAD elements in registers; sorted
// runs are then merged pairwise, doubling the number of cooperating threads
// (2, 4, ..., BLOCK_THREADS) on every round. When ARG_SORT is set, the index
// array is permuted together with the values.
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp>
struct BlockMergeSort {
  using thread_sort_t =
      ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;

  // Binary-searches how many elements of the sorted run `As` belong to the
  // first `sort_md` elements of merge(As, Bs). The remaining
  // `sort_md - result` elements of that prefix come from `Bs`.
  static METAL_FUNC int merge_partition(
      const threadgroup val_t* As,
      const threadgroup val_t* Bs,
      short A_sz,
      short B_sz,
      short sort_md) {
    CompareOp op;

    // Search window for the number of elements taken from As.
    short lo = max(0, sort_md - B_sz);
    short hi = min(sort_md, A_sz);

    while (lo < hi) {
      short mid = lo + (hi - lo) / 2;
      auto a = As[mid];
      auto b = Bs[sort_md - 1 - mid];

      if (op(b, a)) {
        hi = mid;
      } else {
        lo = mid + 1;
      }
    }

    return hi;
  }

  // Produces the next N_PER_THREAD elements of merge(As, Bs) into the
  // caller's registers. An element of Bs is taken when Bs is not exhausted
  // and either As is exhausted or the Bs element orders strictly first.
  static METAL_FUNC void merge_step(
      const threadgroup val_t* As,
      const threadgroup val_t* Bs,
      const threadgroup idx_t* As_idx,
      const threadgroup idx_t* Bs_idx,
      short A_sz,
      short B_sz,
      thread val_t (&vals)[N_PER_THREAD],
      thread idx_t (&idxs)[N_PER_THREAD]) {
    CompareOp op;
    short a_pos = 0;
    short b_pos = 0;

    for (int slot = 0; slot < N_PER_THREAD; ++slot) {
      auto a = As[a_pos];
      auto b = Bs[b_pos];
      bool take_b = (b_pos < B_sz) && (a_pos >= A_sz || op(b, a));

      vals[slot] = take_b ? b : a;
      idxs[slot] = take_b ? Bs_idx[b_pos] : As_idx[a_pos];

      // Exactly one of the two cursors advances.
      b_pos += short(take_b);
      a_pos += short(!take_b);
    }
  }

  // Sorts the block in place. On return, tgp_vals (and tgp_idxs when ARG_SORT
  // is set) hold the sorted sequence. Threads whose first element index is at
  // or beyond size_sorted_axis skip the per-thread sort but still take part
  // in every merge round and barrier.
  static METAL_FUNC void sort(
      threadgroup val_t* tgp_vals [[threadgroup(0)]],
      threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
      int size_sorted_axis,
      uint3 lid [[thread_position_in_threadgroup]]) {
    // First element owned by this thread.
    const int base = lid.x * N_PER_THREAD;

    // Pull this thread's slice into registers.
    thread val_t thread_vals[N_PER_THREAD];
    thread idx_t thread_idxs[N_PER_THREAD];
    for (int i = 0; i < N_PER_THREAD; ++i) {
      thread_vals[i] = tgp_vals[base + i];
      if (ARG_SORT) {
        thread_idxs[i] = tgp_idxs[base + i];
      }
    }

    // Per-thread sort.
    if (base < size_sorted_axis) {
      thread_sort_t::sort(thread_vals, thread_idxs);
    }

    // Merge rounds through threadgroup memory.
    for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
         merge_threads *= 2) {
      // Publish the current registers; the barriers on both sides keep the
      // reads of the previous round and of this round apart from the writes.
      threadgroup_barrier(mem_flags::mem_threadgroup);
      for (int i = 0; i < N_PER_THREAD; ++i) {
        tgp_vals[base + i] = thread_vals[i];
        if (ARG_SORT) {
          tgp_idxs[base + i] = thread_idxs[i];
        }
      }
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Position of this thread within the current round.
      const int merge_group = lid.x / merge_threads;
      const int merge_lane = lid.x % merge_threads;

      const int sort_sz = N_PER_THREAD * merge_threads;
      const int sort_st = N_PER_THREAD * merge_threads * merge_group;
      const int half_sz = sort_sz / 2;

      // The group's span is two sorted runs laid out back to back:
      //   As = tgp_vals[A_st : A_st + half_sz]
      //   Bs = tgp_vals[B_st : sort_st + sort_sz]
      const int A_st = sort_st;
      const int B_st = sort_st + half_sz;

      const threadgroup val_t* As = tgp_vals + A_st;
      const threadgroup val_t* Bs = tgp_vals + B_st;
      int A_sz = half_sz;
      int B_sz = sort_sz - half_sz;

      // Lane i outputs elements [sort_md, sort_md + N_PER_THREAD) of the
      // merged sequence; find where that range starts in As and Bs.
      const int sort_md = N_PER_THREAD * merge_lane;
      const int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);
      const int b_skip = sort_md - partition;

      As += partition;
      Bs += b_skip;

      A_sz -= partition;
      B_sz -= b_skip;

      const threadgroup idx_t* As_idx =
          ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
      const threadgroup idx_t* Bs_idx =
          ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;

      // Merge from the partition point into registers.
      merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
    }

    // Final write-back to threadgroup memory.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int i = 0; i < N_PER_THREAD; ++i) {
      tgp_vals[base + i] = thread_vals[i];
      if (ARG_SORT) {
        tgp_idxs[base + i] = thread_idxs[i];
      }
    }
  }
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Kernel sort
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Single-threadgroup sort of one segment ("row") of the input.
//
// T is the value type read from `inp`; U is the element type written to `out`
// (sorted indices when ARG_SORT is true, sorted values otherwise). Each
// threadgroup holds N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD elements in
// threadgroup memory and sorts them with BlockMergeSort.
template <
    typename T,
    typename U,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan<T>>
struct KernelMergeSort {
  using val_t = T;
  using idx_t = uint;
  using block_merge_sort_t = BlockMergeSort<
      val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp>;

  // Number of elements one threadgroup can hold and sort.
  MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;

  // Sorts the segment selected by tid.y.
  //
  // inp / out                   base pointers of the whole input / output
  // size_sorted_axis            number of valid elements in the segment
  // in/out_stride_sorted_axis   element stride along the sorted axis
  // in/out_stride_segment_axis  element stride between consecutive segments
  // tgp_vals / tgp_idxs         threadgroup scratch of N_PER_BLOCK entries;
  //                             tgp_idxs is only accessed here when ARG_SORT
  static METAL_FUNC void block_sort(
      const device T* inp,
      device U* out,
      const constant int& size_sorted_axis,
      const constant int& in_stride_sorted_axis,
      const constant int& out_stride_sorted_axis,
      const constant int& in_stride_segment_axis,
      const constant int& out_stride_segment_axis,
      threadgroup val_t* tgp_vals,
      threadgroup idx_t* tgp_idxs,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Move to the segment handled by this threadgroup.
    const device T* row_in = inp + tid.y * in_stride_segment_axis;
    device U* row_out = out + tid.y * out_stride_segment_axis;

    // Stage the segment in threadgroup memory; slots past the end of the
    // segment are filled with CompareOp::init.
    for (short slot = lid.x; slot < N_PER_BLOCK; slot += BLOCK_THREADS) {
      if (slot < size_sorted_axis) {
        tgp_vals[slot] = row_in[slot * in_stride_sorted_axis];
      } else {
        tgp_vals[slot] = val_t(CompareOp::init);
      }
      if (ARG_SORT) {
        tgp_idxs[slot] = slot;
      }
    }

    // All slots must be written before any thread starts sorting.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);

    // The sort must be complete before results are read back.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write the first size_sorted_axis results: indices or values.
    if (ARG_SORT) {
      for (int pos = lid.x; pos < size_sorted_axis; pos += BLOCK_THREADS) {
        row_out[pos * out_stride_sorted_axis] = tgp_idxs[pos];
      }
    } else {
      for (int pos = lid.x; pos < size_sorted_axis; pos += BLOCK_THREADS) {
        row_out[pos * out_stride_sorted_axis] = tgp_vals[pos];
      }
    }
  }
};
|
||||||
|
|
||||||
|
// Kernel entry point: sorts one segment per threadgroup (segment = tid.y),
// with segments addressed through a single stride per buffer.
// Allocates the threadgroup scratch and forwards to
// KernelMergeSort::block_sort. The index scratch is only allocated when
// ARG_SORT is true.
template <
    typename T,
    typename U,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
    const device T* inp [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& in_stride_sorted_axis [[buffer(3)]],
    const constant int& out_stride_sorted_axis [[buffer(4)]],
    const constant int& in_stride_segment_axis [[buffer(5)]],
    const constant int& out_stride_segment_axis [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel =
      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
  using val_t = typename sort_kernel::val_t;
  using idx_t = typename sort_kernel::idx_t;

  if (!ARG_SORT) {
    // Value sort: no index scratch is needed.
    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        in_stride_segment_axis,
        out_stride_segment_axis,
        tgp_vals,
        nullptr,
        tid,
        lid);
  } else {
    // Arg sort: values and their original positions are sorted together.
    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
    threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        in_stride_segment_axis,
        out_stride_segment_axis,
        tgp_vals,
        tgp_idxs,
        tid,
        lid);
  }
}
|
||||||
|
|
||||||
|
// Zero in the constant address space, so it can be bound to the
// `const constant int&` stride parameters of KernelMergeSort::block_sort.
constant constexpr const int zero_helper = 0;
|
||||||
|
|
||||||
|
// Kernel entry point for segments that cannot be addressed with a single
// stride: the start of segment tid.y is resolved with elem_to_loc from
// nc_shape / nc_strides (nc_dim entries each), separately for input and
// output. The segment strides forwarded to KernelMergeSort::block_sort are
// therefore zero, since the pointers are already positioned.
template <
    typename T,
    typename U,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
    const device T* inp [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& in_stride_sorted_axis [[buffer(3)]],
    const constant int& out_stride_sorted_axis [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const constant int* nc_shape [[buffer(6)]],
    const constant int64_t* in_nc_strides [[buffer(7)]],
    const constant int64_t* out_nc_strides [[buffer(8)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel =
      KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
  using val_t = typename sort_kernel::val_t;
  using idx_t = typename sort_kernel::idx_t;

  // Position both pointers at the start of this threadgroup's segment.
  auto in_offset = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
  auto out_offset = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
  inp += in_offset;
  out += out_offset;

  if (!ARG_SORT) {
    // Value sort: no index scratch is needed.
    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        zero_helper,
        zero_helper,
        tgp_vals,
        nullptr,
        tid,
        lid);
  } else {
    // Arg sort: values and their original positions are sorted together.
    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
    threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
    sort_kernel::block_sort(
        inp,
        out,
        size_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        zero_helper,
        zero_helper,
        tgp_vals,
        tgp_idxs,
        tid,
        lid);
  }
}
|
||||||
|
|
||||||
|
// Building blocks for sorting a row that spans several threadgroups.
// Each threadgroup owns one tile of N_PER_BLOCK consecutive elements.
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan<val_t>>
struct KernelMultiBlockMergeSort {
  using block_merge_sort_t = BlockMergeSort<
      val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp>;

  // Number of elements in one tile.
  MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;

  // Sorts tile tid.x of the current row and writes the sorted values and
  // their positions within the row to out_vals / out_idxs.
  //
  // inp must already point at the row; out_vals / out_idxs are indexed
  // contiguously. Elements at or past size_sorted_axis are padded with
  // CompareOp::init for the sort and are not written back.
  static METAL_FUNC void block_sort(
      const device val_t* inp,
      device val_t* out_vals,
      device idx_t* out_idxs,
      const constant int& size_sorted_axis,
      const constant int& stride_sorted_axis,
      threadgroup val_t* tgp_vals,
      threadgroup idx_t* tgp_idxs,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // tid.x selects the tile; this is its first position in the row.
    int tile_start = tid.x * N_PER_BLOCK;

    // Stage the tile in threadgroup memory.
    for (short slot = lid.x; slot < N_PER_BLOCK; slot += BLOCK_THREADS) {
      int pos = tile_start + slot;
      if (pos < size_sorted_axis) {
        tgp_vals[slot] = inp[pos * stride_sorted_axis];
      } else {
        tgp_vals[slot] = val_t(CompareOp::init);
      }
      tgp_idxs[slot] = pos;
    }

    // All slots must be written before any thread starts sorting.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);

    // The sort must be complete before results are read back.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write back the in-range part of the tile.
    for (int slot = lid.x; slot < N_PER_BLOCK; slot += BLOCK_THREADS) {
      int pos = tile_start + slot;
      if (pos < size_sorted_axis) {
        out_vals[pos] = tgp_vals[slot];
        out_idxs[pos] = tgp_idxs[slot];
      }
    }
  }

  // Binary search for the number of elements of As that fall within the
  // first sort_md elements of merge(As, Bs); the remaining
  // sort_md - result elements come from Bs.
  // As (A_sz elements) and Bs (B_sz elements) are expected to be sorted
  // under CompareOp.
  static METAL_FUNC int merge_partition(
      const device val_t* As,
      const device val_t* Bs,
      int A_sz,
      int B_sz,
      int sort_md) {
    CompareOp op;

    // Candidate range for the split point within As.
    int lo = max(0, sort_md - B_sz);
    int hi = min(sort_md, A_sz);

    while (lo < hi) {
      int mid = lo + (hi - lo) / 2;
      auto a = As[mid];
      auto b = Bs[sort_md - 1 - mid];

      if (op(b, a)) {
        hi = mid;
      } else {
        lo = mid + 1;
      }
    }

    return hi;
  }
};
|
||||||
|
|
||||||
|
// First pass of the multi-block sort: threadgroup (tid.x, tid.y) sorts tile
// tid.x of row tid.y. The input row start is resolved with elem_to_loc; the
// outputs are laid out contiguously with size_sorted_axis elements per row.
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
    const device val_t* inp [[buffer(0)]],
    device val_t* out_vals [[buffer(1)]],
    device idx_t* out_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& stride_sorted_axis [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const constant int* nc_shape [[buffer(6)]],
    const constant int64_t* nc_strides [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;

  // Position the pointers at the row handled by this threadgroup.
  auto row_offset = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
  inp += row_offset;
  out_vals += tid.y * size_sorted_axis;
  out_idxs += tid.y * size_sorted_axis;

  // Scratch for one tile of values and their row positions.
  threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
  threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];

  sort_kernel::block_sort(
      inp,
      out_vals,
      out_idxs,
      size_sorted_axis,
      stride_sorted_axis,
      tgp_vals,
      tgp_idxs,
      tid,
      lid);
}
|
||||||
|
|
||||||
|
// Computes, for every row, the n_blocks + 1 split points used by
// mb_block_merge.
//
// Tiles are processed in merge groups of merge_tiles tiles. Within a group
// the first half (A) and the second half (B) are each already sorted.
// block_partitions[i] receives the absolute position in the row where the
// A-part of output tile i starts.
//
// One threadgroup handles one row (tid.y); its threads stride over the
// n_blocks + 1 entries. dev_idxs is offset for symmetry but not read here.
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel]] void mb_block_partition(
    device idx_t* block_partitions [[buffer(0)]],
    const device val_t* dev_vals [[buffer(1)]],
    const device idx_t* dev_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& merge_tiles [[buffer(4)]],
    const constant int& n_blocks [[buffer(5)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_dims [[threads_per_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;

  // Each row owns n_blocks + 1 entries. This must match the row stride used
  // by mb_block_merge (num_tiles + 1) regardless of how many threads were
  // launched: using tgp_dims.x here makes rows overlap as soon as the
  // threadgroup is smaller than n_blocks + 1.
  block_partitions += tid.y * (n_blocks + 1);
  dev_vals += tid.y * size_sorted_axis;
  dev_idxs += tid.y * size_sorted_axis;

  for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
    // Locate entry i within its merge group.
    int merge_group = i / merge_tiles;
    int merge_lane = i % merge_tiles;

    int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
    int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;

    // Sorted halves of the group, clamped to the end of the row.
    int A_st = min(size_sorted_axis, sort_st);
    int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
    int B_st = A_ed;
    int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);

    // Rank in the merged group at which output tile merge_lane begins.
    int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
    int partition = sort_kernel::merge_partition(
        dev_vals + A_st,
        dev_vals + B_st,
        A_ed - A_st,
        B_ed - B_st,
        partition_at);

    block_partitions[i] = A_st + partition;
  }
}
|
||||||
|
|
||||||
|
// Merge pass of the multi-block sort: threadgroup (tid.x, tid.y) produces
// output tile tid.x of row tid.y by merging a slice of the sorted A half
// and a slice of the sorted B half of its merge group. The slice bounds in
// A come from block_partitions; the bounds in B are derived from them.
template <
    typename val_t,
    typename idx_t,
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD,
    typename CompareOp = LessThan<val_t>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_merge(
    const device idx_t* block_partitions [[buffer(0)]],
    const device val_t* dev_vals_in [[buffer(1)]],
    const device idx_t* dev_idxs_in [[buffer(2)]],
    device val_t* dev_vals_out [[buffer(3)]],
    device idx_t* dev_idxs_out [[buffer(4)]],
    const constant int& size_sorted_axis [[buffer(5)]],
    const constant int& merge_tiles [[buffer(6)]],
    const constant int& num_tiles [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel = KernelMultiBlockMergeSort<
      val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp>;
  using block_sort_t = typename sort_kernel::block_merge_sort_t;

  // Position every pointer at the row handled by this threadgroup.
  block_partitions += tid.y * (num_tiles + 1);
  dev_vals_in += tid.y * size_sorted_axis;
  dev_idxs_in += tid.y * size_sorted_axis;
  dev_vals_out += tid.y * size_sorted_axis;
  dev_idxs_out += tid.y * size_sorted_axis;

  // Locate this tile within its merge group.
  int tile = tid.x;
  int group = tile / merge_tiles;
  int group_start = sort_kernel::N_PER_BLOCK * merge_tiles * group;
  int group_len = sort_kernel::N_PER_BLOCK * merge_tiles;
  int tile_rank = sort_kernel::N_PER_BLOCK * tile - group_start;

  // Slice of A for this tile, and the matching slice of B.
  int A_st = block_partitions[tile + 0];
  int A_ed = block_partitions[tile + 1];
  int B_st =
      min(size_sorted_axis, 2 * group_start + group_len / 2 + tile_rank - A_st);
  int B_ed = min(
      size_sorted_axis,
      2 * group_start + group_len / 2 + tile_rank + sort_kernel::N_PER_BLOCK -
          A_ed);

  // The last tile of a group consumes whatever is left of both halves.
  if ((tile % merge_tiles) == merge_tiles - 1) {
    A_ed = min(size_sorted_axis, group_start + group_len / 2);
    B_ed = min(size_sorted_axis, group_start + group_len);
  }

  int A_sz = A_ed - A_st;
  int B_sz = B_ed - B_st;

  // Load the A slice followed by the B slice; slots past A_sz + B_sz are
  // padded with CompareOp::init and index 0.
  thread val_t thread_vals[N_PER_THREAD];
  thread idx_t thread_idxs[N_PER_THREAD];
  for (int k = 0; k < N_PER_THREAD; k++) {
    int slot = BLOCK_THREADS * k + lid.x;
    if (slot < (A_sz + B_sz)) {
      int src = (slot < A_sz) ? A_st + slot : B_st + slot - A_sz;
      thread_vals[k] = dev_vals_in[src];
      thread_idxs[k] = dev_idxs_in[src];
    } else {
      thread_vals[k] = CompareOp::init;
      thread_idxs[k] = 0;
    }
  }

  // Stage both slices in threadgroup memory: A at [0, A_sz), B after it.
  threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
  threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (int k = 0; k < N_PER_THREAD; k++) {
    int slot = BLOCK_THREADS * k + lid.x;
    tgp_vals[slot] = thread_vals[k];
    tgp_idxs[slot] = thread_idxs[k];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Each thread merges the N_PER_THREAD outputs starting at this rank.
  int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));

  int A_st_local = block_sort_t::merge_partition(
      tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
  int A_ed_local = A_sz;

  int B_st_local = sort_md_local - A_st_local;
  int B_ed_local = B_sz;

  int A_sz_local = A_ed_local - A_st_local;
  int B_sz_local = B_ed_local - B_st_local;

  // Merge from threadgroup memory into this thread's registers.
  block_sort_t::merge_step(
      tgp_vals + A_st_local,
      tgp_vals + A_ed_local + B_st_local,
      tgp_idxs + A_st_local,
      tgp_idxs + A_ed_local + B_st_local,
      A_sz_local,
      B_sz_local,
      thread_vals,
      thread_idxs);

  // Every thread must be done reading before the scratch is overwritten.
  threadgroup_barrier(mem_flags::mem_threadgroup);
  int out_slot = lid.x * N_PER_THREAD;
  for (int k = 0; k < N_PER_THREAD; ++k) {
    tgp_vals[out_slot + k] = thread_vals[k];
    tgp_idxs[out_slot + k] = thread_idxs[k];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write the merged tile back, skipping positions past the end of the row.
  int tile_start = tid.x * sort_kernel::N_PER_BLOCK;
  for (int slot = lid.x; slot < sort_kernel::N_PER_BLOCK;
       slot += BLOCK_THREADS) {
    int pos = tile_start + slot;
    if (pos < size_sorted_axis) {
      dev_vals_out[pos] = tgp_vals[slot];
      dev_idxs_out[pos] = tgp_idxs[slot];
    }
  }
}
|
||||||
|
|
||||||
|
// Instantiates the strided ("c" prefix) and elem_to_loc-addressed ("nc"
// prefix) single-block kernels for one type / block-size combination.
#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \
  instantiate_kernel(                                                   \
      "c" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn,            \
      block_sort, itype, otype, arg_sort, bn, tn)                       \
  instantiate_kernel(                                                   \
      "nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn,           \
      block_sort_nc, itype, otype, arg_sort, bn, tn)

// Arg-sort variant: output is always uint32 indices.
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
  instantiate_block_sort(arg_block_sort, itname, itype, uint32, uint32_t, true, bn, tn)

// Value-sort variant: output type equals input type.
#define instantiate_block_sort_base(itname, itype, bn, tn) \
  instantiate_block_sort(_block_sort, itname, itype, itname, itype, false, bn, tn)

// Both variants with 8 elements per thread.
#define instantiate_block_sort_tn(itname, itype, bn) \
  instantiate_block_sort_base(itname, itype, bn, 8)  \
  instantiate_arg_block_sort_base(itname, itype, bn, 8)

// Threadgroup sizes instantiated for the types listed right below.
#define instantiate_block_sort_bn(itname, itype) \
  instantiate_block_sort_tn(itname, itype, 128)  \
  instantiate_block_sort_tn(itname, itype, 256)  \
  instantiate_block_sort_tn(itname, itype, 512)

instantiate_block_sort_bn(uint8, uint8_t)
instantiate_block_sort_bn(uint32, uint32_t)
instantiate_block_sort_bn(float16, half)
instantiate_block_sort_bn(float32, float)
instantiate_block_sort_bn(bfloat16, bfloat16_t)

// int64 is only instantiated for 128 and 256 threads per threadgroup.
#define instantiate_block_sort_long(itname, itype) \
  instantiate_block_sort_tn(itname, itype, 128)    \
  instantiate_block_sort_tn(itname, itype, 256)

instantiate_block_sort_long(int64, int64_t)

// Instantiates the three kernels of the multi-block sort (tile sort,
// partition, merge) for one value type / block-size combination.
#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \
  instantiate_kernel(                                                   \
      "sort_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn,           \
      mb_block_sort, vtype, itype, arg_sort, bn, tn)                    \
  instantiate_kernel(                                                   \
      "partition_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn,      \
      mb_block_partition, vtype, itype, arg_sort, bn, tn)               \
  instantiate_kernel(                                                   \
      "merge_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn,          \
      mb_block_merge, vtype, itype, arg_sort, bn, tn)

// Multi-block sort with uint32 indices, 512 threads and 8 elements per thread.
#define instantiate_multi_block_sort_base(vtname, vtype) \
  instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)

instantiate_multi_block_sort_base(uint8, uint8_t)
instantiate_multi_block_sort_base(uint32, uint32_t)
instantiate_multi_block_sort_base(float16, half)
instantiate_multi_block_sort_base(float32, float)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)

// int64 uses 256 threads per threadgroup instead of 512.
#define instantiate_multi_block_sort_long(vtname, vtype) \
  instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)

instantiate_multi_block_sort_long(int64, int64_t)
// clang-format on
|
296
candle-metal-kernels/src/sort.rs
Normal file
296
candle-metal-kernels/src/sort.rs
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
use crate::utils::{BufferOffset, EncoderProvider};
|
||||||
|
use crate::{set_params, DType, Kernels, MetalKernelError, Source};
|
||||||
|
use metal::{Buffer, ComputeCommandEncoderRef, Device, MTLResourceOptions, MTLSize};
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn call_arg_sort(
|
||||||
|
device: &Device,
|
||||||
|
ep: impl EncoderProvider,
|
||||||
|
kernels: &Kernels,
|
||||||
|
name: &'static str,
|
||||||
|
nrows: usize,
|
||||||
|
ncols: usize,
|
||||||
|
ncols_pad: usize,
|
||||||
|
src: BufferOffset,
|
||||||
|
dst: &Buffer,
|
||||||
|
) -> Result<(), crate::MetalKernelError> {
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
|
||||||
|
let encoder = ep.encoder();
|
||||||
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
|
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
|
||||||
|
|
||||||
|
let thread_group_count = MTLSize {
|
||||||
|
width: 1,
|
||||||
|
height: nrows as u64,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
let thread_group_size = MTLSize {
|
||||||
|
width: ncols_pad as u64,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
||||||
|
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mlx_dtype_str(dtype: DType) -> &'static str {
|
||||||
|
match dtype {
|
||||||
|
DType::U8 => "uint8",
|
||||||
|
DType::U32 => "uint32",
|
||||||
|
DType::I64 => "int64",
|
||||||
|
DType::F16 => "float16",
|
||||||
|
DType::BF16 => "bfloat16",
|
||||||
|
DType::F32 => "float32",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn multi_block_sort(
|
||||||
|
device: &Device,
|
||||||
|
ep: impl EncoderProvider,
|
||||||
|
kernels: &Kernels,
|
||||||
|
dtype: DType,
|
||||||
|
bn: usize,
|
||||||
|
tn: usize,
|
||||||
|
nblocks: usize,
|
||||||
|
nrows: usize,
|
||||||
|
ncols: usize,
|
||||||
|
src: BufferOffset,
|
||||||
|
dst: &Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let dtype_str = mlx_dtype_str(dtype);
|
||||||
|
// Do allocations
|
||||||
|
let el_count = nrows * ncols;
|
||||||
|
let bytes_len = (el_count * dtype.size_in_bytes()) as u64;
|
||||||
|
let mut dev_vals_0 = device.new_buffer(bytes_len, MTLResourceOptions::StorageModePrivate);
|
||||||
|
let mut dev_vals_1 = device.new_buffer(bytes_len, MTLResourceOptions::StorageModePrivate);
|
||||||
|
let mut dev_idxs_0 =
|
||||||
|
device.new_buffer(el_count as u64 * 4, MTLResourceOptions::StorageModePrivate);
|
||||||
|
let mut dev_idxs_1 =
|
||||||
|
device.new_buffer(el_count as u64 * 4, MTLResourceOptions::StorageModePrivate);
|
||||||
|
let mut block_partitions = device.new_buffer(
|
||||||
|
(nrows * (nblocks + 1)) as u64 * 4,
|
||||||
|
MTLResourceOptions::StorageModePrivate,
|
||||||
|
);
|
||||||
|
// Prepare command encoder
|
||||||
|
let encoder = ep.encoder();
|
||||||
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
|
// Do blockwise sort
|
||||||
|
{
|
||||||
|
let name = format!("sort_mbsort_{dtype_str}_uint32_bn{bn}_tn{tn}");
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?;
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
set_params!(
|
||||||
|
encoder,
|
||||||
|
(
|
||||||
|
&src,
|
||||||
|
&mut dev_vals_0,
|
||||||
|
&mut dev_idxs_0,
|
||||||
|
/* size_sorted_axis */ ncols as i32,
|
||||||
|
/* stride_sorted_axis */ 1i32,
|
||||||
|
/* nc_dim */ 1i32,
|
||||||
|
/* nc_shape */ nrows as i32,
|
||||||
|
/* nc_str */ ncols as i32
|
||||||
|
)
|
||||||
|
);
|
||||||
|
let thread_group_count = MTLSize {
|
||||||
|
width: nblocks as u64,
|
||||||
|
height: nrows as u64,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
let thread_group_size = MTLSize {
|
||||||
|
width: bn as u64,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
}
|
||||||
|
// Do merges
|
||||||
|
let mut ping = false;
|
||||||
|
let mut merge_tiles = 2;
|
||||||
|
let n_thr_per_group = usize::min(nblocks + 1, 1024);
|
||||||
|
let partition_name = format!("partition_mbsort_{dtype_str}_uint32_bn{bn}_tn{tn}");
|
||||||
|
let merge_name = format!("merge_mbsort_float32_uint32_bn{bn}_tn{tn}");
|
||||||
|
while merge_tiles / 2 < nblocks {
|
||||||
|
let (dev_vals_in, dev_vals_out) = if ping {
|
||||||
|
(&mut dev_vals_1, &mut dev_vals_0)
|
||||||
|
} else {
|
||||||
|
(&mut dev_vals_0, &mut dev_vals_1)
|
||||||
|
};
|
||||||
|
let (dev_idxs_in, dev_idxs_out) = if ping {
|
||||||
|
(&mut dev_idxs_1, &mut dev_idxs_0)
|
||||||
|
} else {
|
||||||
|
(&mut dev_idxs_0, &mut dev_idxs_1)
|
||||||
|
};
|
||||||
|
ping = !ping;
|
||||||
|
// Do partition
|
||||||
|
{
|
||||||
|
let pipeline =
|
||||||
|
kernels.load_pipeline(device, Source::MlxSort, partition_name.clone())?;
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
set_params!(
|
||||||
|
encoder,
|
||||||
|
(
|
||||||
|
&mut block_partitions,
|
||||||
|
&mut *dev_vals_in,
|
||||||
|
&mut *dev_idxs_in,
|
||||||
|
/* size_sorted_axis */ ncols as i32,
|
||||||
|
/* merge_tiles */ merge_tiles as i32,
|
||||||
|
/* n_blocks */ nblocks as i32
|
||||||
|
)
|
||||||
|
);
|
||||||
|
let thread_group_count = MTLSize {
|
||||||
|
width: 1,
|
||||||
|
height: nrows as u64,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
let thread_group_size = MTLSize {
|
||||||
|
width: n_thr_per_group as u64,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
}
|
||||||
|
// Do merge
|
||||||
|
{
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::MlxSort, merge_name.clone())?;
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
set_params!(
|
||||||
|
encoder,
|
||||||
|
(
|
||||||
|
&block_partitions,
|
||||||
|
&*dev_vals_in,
|
||||||
|
&*dev_idxs_in,
|
||||||
|
&*dev_vals_out,
|
||||||
|
&*dev_idxs_out,
|
||||||
|
/* size_sorted_axis */ ncols as i32,
|
||||||
|
/* merge_tiles */ merge_tiles as i32,
|
||||||
|
/* n_blocks */ nblocks as i32
|
||||||
|
)
|
||||||
|
);
|
||||||
|
let thread_group_count = MTLSize {
|
||||||
|
width: nblocks as u64,
|
||||||
|
height: nrows as u64,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
let thread_group_size = MTLSize {
|
||||||
|
width: bn as u64,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
}
|
||||||
|
merge_tiles *= 2;
|
||||||
|
}
|
||||||
|
let dev_idxs_out = if ping {
|
||||||
|
&mut dev_idxs_1
|
||||||
|
} else {
|
||||||
|
&mut dev_idxs_0
|
||||||
|
};
|
||||||
|
// Copy output with appropriate strides
|
||||||
|
let copy_kernel = match dtype {
|
||||||
|
DType::U8 => crate::copy2d::U8,
|
||||||
|
DType::U32 => crate::copy2d::U32,
|
||||||
|
DType::I64 => crate::copy2d::I64,
|
||||||
|
DType::BF16 => crate::copy2d::BFLOAT,
|
||||||
|
DType::F16 => crate::copy2d::HALF,
|
||||||
|
DType::F32 => crate::copy2d::FLOAT,
|
||||||
|
};
|
||||||
|
crate::call_copy2d(
|
||||||
|
device,
|
||||||
|
encoder,
|
||||||
|
kernels,
|
||||||
|
copy_kernel,
|
||||||
|
dev_idxs_out,
|
||||||
|
dst,
|
||||||
|
/* d1 */ nrows,
|
||||||
|
/* d2 */ ncols,
|
||||||
|
/* src_s */ ncols,
|
||||||
|
/* dst_s */ ncols,
|
||||||
|
/* src_o_in_bytes */ 0,
|
||||||
|
/*dst_o_in_bytes */ 0,
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn block_sort(
|
||||||
|
device: &Device,
|
||||||
|
ep: impl EncoderProvider,
|
||||||
|
kernels: &Kernels,
|
||||||
|
dtype: DType,
|
||||||
|
bn: usize,
|
||||||
|
tn: usize,
|
||||||
|
nrows: usize,
|
||||||
|
ncols: usize,
|
||||||
|
src: BufferOffset,
|
||||||
|
dst: &Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let dtype_str = mlx_dtype_str(dtype);
|
||||||
|
let name = format!("carg_block_sort_{dtype_str}_uint32_bn{bn}_tn{tn}");
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?;
|
||||||
|
let encoder = ep.encoder();
|
||||||
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
set_params!(
|
||||||
|
encoder,
|
||||||
|
(
|
||||||
|
&src,
|
||||||
|
dst,
|
||||||
|
ncols as i32,
|
||||||
|
1i32,
|
||||||
|
1i32,
|
||||||
|
ncols as i32,
|
||||||
|
ncols as i32
|
||||||
|
)
|
||||||
|
);
|
||||||
|
let thread_group_count = MTLSize {
|
||||||
|
width: 1,
|
||||||
|
height: nrows as u64,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
let thread_group_size = MTLSize {
|
||||||
|
width: bn as u64,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn call_mlx_arg_sort(
|
||||||
|
device: &Device,
|
||||||
|
ep: impl EncoderProvider,
|
||||||
|
kernels: &Kernels,
|
||||||
|
dtype: DType,
|
||||||
|
nrows: usize,
|
||||||
|
ncols: usize,
|
||||||
|
src: BufferOffset,
|
||||||
|
dst: &Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let tn = 8;
|
||||||
|
let bn = match ncols.div_ceil(tn) {
|
||||||
|
257.. if dtype.size_in_bytes() <= 4 => 512,
|
||||||
|
129.. => 256,
|
||||||
|
0..129 => 128,
|
||||||
|
};
|
||||||
|
let n_per_block = bn * tn;
|
||||||
|
let n_blocks = ncols.div_ceil(n_per_block);
|
||||||
|
if n_blocks > 1 {
|
||||||
|
multi_block_sort(
|
||||||
|
device, ep, kernels, dtype, bn, tn, n_blocks, nrows, ncols, src, dst,
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
|
block_sort(device, ep, kernels, dtype, bn, tn, nrows, ncols, src, dst)?
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -605,6 +605,69 @@ fn affine_strided() {
|
|||||||
assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
|
assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn run_mlx_sort<T: Clone>(v: &[T], ncols: usize) -> Vec<u32> {
|
||||||
|
let nrows = v.len() / ncols;
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
|
||||||
|
let input = new_buffer(&device, v);
|
||||||
|
let indexes = vec![0u32; v.len()];
|
||||||
|
let output = new_buffer(&device, &indexes);
|
||||||
|
|
||||||
|
call_mlx_arg_sort(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
DType::F32,
|
||||||
|
nrows,
|
||||||
|
ncols,
|
||||||
|
BufferOffset::zero_offset(&input),
|
||||||
|
&output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
read_to_vec(&output, v.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks the MLX arg-sort against known permutations: small ascending and descending rows,
/// five descending rows of 200 elements, and a seeded random 16 x 16000 input that is compared
/// with a CPU reference permutation.
#[test]
fn mlx_sort() {
    use rand::SeedableRng;
    use rand_distr::Distribution;

    // Two rows of four ascending values: every row is already in order.
    let ascending: Vec<_> = (0..8).map(|v| v as f32).collect();
    let sorted = run_mlx_sort(&ascending, 4);
    assert_eq!(sorted, [0, 1, 2, 3, 0, 1, 2, 3]);

    // Two rows of four descending values: every row gets reversed.
    let descending: Vec<_> = (0..8).rev().map(|v| v as f32).collect();
    let sorted = run_mlx_sort(&descending, 4);
    assert_eq!(sorted, [3, 2, 1, 0, 3, 2, 1, 0]);

    // Five descending rows of 200 values; each row must come back as 199, 198, ..., 0.
    let descending: Vec<_> = (0..1000).rev().map(|v| v as f32).collect();
    let sorted = run_mlx_sort(&descending, 200);
    let expected: Vec<_> = (0..200).rev().collect();
    for row in 0..4 {
        assert_eq!(&sorted[row * 200..(row + 1) * 200], expected);
    }
    // The last row is taken open-ended so that any trailing extra output is also caught.
    assert_eq!(&sorted[800..], expected);

    // Multi-block test
    let ncols = 16000;
    let mut rng = rand::rngs::StdRng::seed_from_u64(299792458);
    let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
    let input: Vec<f32> = (0..ncols * 16).map(|_| normal.sample(&mut rng)).collect();
    let result = run_mlx_sort(&input, ncols);
    for row in 0..16 {
        let bounds = row * ncols..(row + 1) * ncols;
        let values = &input[bounds.clone()];
        let gpu_order = &result[bounds];

        // CPU reference: stable sort of the column indices by their value in this row.
        let mut cpu_order: Vec<usize> = (0..ncols).collect();
        cpu_order.sort_by(|lhs, rhs| values[*lhs].total_cmp(&values[*rhs]));
        let cpu_order: Vec<_> = cpu_order.into_iter().map(|idx| idx as u32).collect();
        assert_eq!(cpu_order, gpu_order);
    }
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn index_select() {
|
fn index_select() {
|
||||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||||
|
Reference in New Issue
Block a user