mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
Add the MLX merge sort kernels (#2751)
* Add some metal sort kernels imported from MLX. * Add another test. * Start adding the multiblock version. * Proper kernel names. * Split out the main metal file. * Multi-block sort. * More sorting. * DType parametrization. * Add a larger test.
This commit is contained in:
296
candle-metal-kernels/src/sort.rs
Normal file
296
candle-metal-kernels/src/sort.rs
Normal file
@ -0,0 +1,296 @@
|
||||
use crate::utils::{BufferOffset, EncoderProvider};
|
||||
use crate::{set_params, DType, Kernels, MetalKernelError, Source};
|
||||
use metal::{Buffer, ComputeCommandEncoderRef, Device, MTLResourceOptions, MTLSize};
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_arg_sort(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
ncols_pad: usize,
|
||||
src: BufferOffset,
|
||||
dst: &Buffer,
|
||||
) -> Result<(), crate::MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: nrows as u64,
|
||||
depth: 1,
|
||||
};
|
||||
let thread_group_size = MTLSize {
|
||||
width: ncols_pad as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
||||
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn mlx_dtype_str(dtype: DType) -> &'static str {
|
||||
match dtype {
|
||||
DType::U8 => "uint8",
|
||||
DType::U32 => "uint32",
|
||||
DType::I64 => "int64",
|
||||
DType::F16 => "float16",
|
||||
DType::BF16 => "bfloat16",
|
||||
DType::F32 => "float32",
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn multi_block_sort(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
dtype: DType,
|
||||
bn: usize,
|
||||
tn: usize,
|
||||
nblocks: usize,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
src: BufferOffset,
|
||||
dst: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let dtype_str = mlx_dtype_str(dtype);
|
||||
// Do allocations
|
||||
let el_count = nrows * ncols;
|
||||
let bytes_len = (el_count * dtype.size_in_bytes()) as u64;
|
||||
let mut dev_vals_0 = device.new_buffer(bytes_len, MTLResourceOptions::StorageModePrivate);
|
||||
let mut dev_vals_1 = device.new_buffer(bytes_len, MTLResourceOptions::StorageModePrivate);
|
||||
let mut dev_idxs_0 =
|
||||
device.new_buffer(el_count as u64 * 4, MTLResourceOptions::StorageModePrivate);
|
||||
let mut dev_idxs_1 =
|
||||
device.new_buffer(el_count as u64 * 4, MTLResourceOptions::StorageModePrivate);
|
||||
let mut block_partitions = device.new_buffer(
|
||||
(nrows * (nblocks + 1)) as u64 * 4,
|
||||
MTLResourceOptions::StorageModePrivate,
|
||||
);
|
||||
// Prepare command encoder
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
// Do blockwise sort
|
||||
{
|
||||
let name = format!("sort_mbsort_{dtype_str}_uint32_bn{bn}_tn{tn}");
|
||||
let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?;
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
&src,
|
||||
&mut dev_vals_0,
|
||||
&mut dev_idxs_0,
|
||||
/* size_sorted_axis */ ncols as i32,
|
||||
/* stride_sorted_axis */ 1i32,
|
||||
/* nc_dim */ 1i32,
|
||||
/* nc_shape */ nrows as i32,
|
||||
/* nc_str */ ncols as i32
|
||||
)
|
||||
);
|
||||
let thread_group_count = MTLSize {
|
||||
width: nblocks as u64,
|
||||
height: nrows as u64,
|
||||
depth: 1,
|
||||
};
|
||||
let thread_group_size = MTLSize {
|
||||
width: bn as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
}
|
||||
// Do merges
|
||||
let mut ping = false;
|
||||
let mut merge_tiles = 2;
|
||||
let n_thr_per_group = usize::min(nblocks + 1, 1024);
|
||||
let partition_name = format!("partition_mbsort_{dtype_str}_uint32_bn{bn}_tn{tn}");
|
||||
let merge_name = format!("merge_mbsort_float32_uint32_bn{bn}_tn{tn}");
|
||||
while merge_tiles / 2 < nblocks {
|
||||
let (dev_vals_in, dev_vals_out) = if ping {
|
||||
(&mut dev_vals_1, &mut dev_vals_0)
|
||||
} else {
|
||||
(&mut dev_vals_0, &mut dev_vals_1)
|
||||
};
|
||||
let (dev_idxs_in, dev_idxs_out) = if ping {
|
||||
(&mut dev_idxs_1, &mut dev_idxs_0)
|
||||
} else {
|
||||
(&mut dev_idxs_0, &mut dev_idxs_1)
|
||||
};
|
||||
ping = !ping;
|
||||
// Do partition
|
||||
{
|
||||
let pipeline =
|
||||
kernels.load_pipeline(device, Source::MlxSort, partition_name.clone())?;
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
&mut block_partitions,
|
||||
&mut *dev_vals_in,
|
||||
&mut *dev_idxs_in,
|
||||
/* size_sorted_axis */ ncols as i32,
|
||||
/* merge_tiles */ merge_tiles as i32,
|
||||
/* n_blocks */ nblocks as i32
|
||||
)
|
||||
);
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: nrows as u64,
|
||||
depth: 1,
|
||||
};
|
||||
let thread_group_size = MTLSize {
|
||||
width: n_thr_per_group as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
}
|
||||
// Do merge
|
||||
{
|
||||
let pipeline = kernels.load_pipeline(device, Source::MlxSort, merge_name.clone())?;
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
&block_partitions,
|
||||
&*dev_vals_in,
|
||||
&*dev_idxs_in,
|
||||
&*dev_vals_out,
|
||||
&*dev_idxs_out,
|
||||
/* size_sorted_axis */ ncols as i32,
|
||||
/* merge_tiles */ merge_tiles as i32,
|
||||
/* n_blocks */ nblocks as i32
|
||||
)
|
||||
);
|
||||
let thread_group_count = MTLSize {
|
||||
width: nblocks as u64,
|
||||
height: nrows as u64,
|
||||
depth: 1,
|
||||
};
|
||||
let thread_group_size = MTLSize {
|
||||
width: bn as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
}
|
||||
merge_tiles *= 2;
|
||||
}
|
||||
let dev_idxs_out = if ping {
|
||||
&mut dev_idxs_1
|
||||
} else {
|
||||
&mut dev_idxs_0
|
||||
};
|
||||
// Copy output with appropriate strides
|
||||
let copy_kernel = match dtype {
|
||||
DType::U8 => crate::copy2d::U8,
|
||||
DType::U32 => crate::copy2d::U32,
|
||||
DType::I64 => crate::copy2d::I64,
|
||||
DType::BF16 => crate::copy2d::BFLOAT,
|
||||
DType::F16 => crate::copy2d::HALF,
|
||||
DType::F32 => crate::copy2d::FLOAT,
|
||||
};
|
||||
crate::call_copy2d(
|
||||
device,
|
||||
encoder,
|
||||
kernels,
|
||||
copy_kernel,
|
||||
dev_idxs_out,
|
||||
dst,
|
||||
/* d1 */ nrows,
|
||||
/* d2 */ ncols,
|
||||
/* src_s */ ncols,
|
||||
/* dst_s */ ncols,
|
||||
/* src_o_in_bytes */ 0,
|
||||
/*dst_o_in_bytes */ 0,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn block_sort(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
dtype: DType,
|
||||
bn: usize,
|
||||
tn: usize,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
src: BufferOffset,
|
||||
dst: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let dtype_str = mlx_dtype_str(dtype);
|
||||
let name = format!("carg_block_sort_{dtype_str}_uint32_bn{bn}_tn{tn}");
|
||||
let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
&src,
|
||||
dst,
|
||||
ncols as i32,
|
||||
1i32,
|
||||
1i32,
|
||||
ncols as i32,
|
||||
ncols as i32
|
||||
)
|
||||
);
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: nrows as u64,
|
||||
depth: 1,
|
||||
};
|
||||
let thread_group_size = MTLSize {
|
||||
width: bn as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_mlx_arg_sort(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
dtype: DType,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
src: BufferOffset,
|
||||
dst: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let tn = 8;
|
||||
let bn = match ncols.div_ceil(tn) {
|
||||
257.. if dtype.size_in_bytes() <= 4 => 512,
|
||||
129.. => 256,
|
||||
0..129 => 128,
|
||||
};
|
||||
let n_per_block = bn * tn;
|
||||
let n_blocks = ncols.div_ceil(n_per_block);
|
||||
if n_blocks > 1 {
|
||||
multi_block_sort(
|
||||
device, ep, kernels, dtype, bn, tn, n_blocks, nrows, ncols, src, dst,
|
||||
)?
|
||||
} else {
|
||||
block_sort(device, ep, kernels, dtype, bn, tn, nrows, ncols, src, dst)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user