Files
candle/candle-metal-kernels/src/sort.rs
Ivar Flakstad 6210fbe9d8 mlx fp8 gemm
2025-05-06 09:55:45 +02:00

299 lines
9.2 KiB
Rust

use crate::utils::{BufferOffset, EncoderProvider};
use crate::{set_params, DType, Kernels, MetalKernelError, Source};
use metal::{Buffer, ComputeCommandEncoderRef, Device, MTLResourceOptions, MTLSize};
#[allow(clippy::too_many_arguments)]
pub fn call_arg_sort(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
nrows: usize,
ncols: usize,
ncols_pad: usize,
src: BufferOffset,
dst: &Buffer,
) -> Result<(), crate::MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
let thread_group_count = MTLSize {
width: 1,
height: nrows as u64,
depth: 1,
};
let thread_group_size = MTLSize {
width: ncols_pad as u64,
height: 1,
depth: 1,
};
encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
fn mlx_dtype_str(dtype: DType) -> &'static str {
match dtype {
DType::FP8(_) => todo!(),
DType::U8 => "uint8",
DType::U32 => "uint32",
DType::I64 => "int64",
DType::F16 => "float16",
DType::BF16 => "bfloat16",
DType::F32 => "float32",
}
}
#[allow(clippy::too_many_arguments)]
pub fn multi_block_sort(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
dtype: DType,
bn: usize,
tn: usize,
nblocks: usize,
nrows: usize,
ncols: usize,
src: BufferOffset,
dst: &Buffer,
) -> Result<(), MetalKernelError> {
let dtype_str = mlx_dtype_str(dtype);
// Do allocations
let el_count = nrows * ncols;
let bytes_len = (el_count * dtype.size_in_bytes()) as u64;
let mut dev_vals_0 = device.new_buffer(bytes_len, MTLResourceOptions::StorageModePrivate);
let mut dev_vals_1 = device.new_buffer(bytes_len, MTLResourceOptions::StorageModePrivate);
let mut dev_idxs_0 =
device.new_buffer(el_count as u64 * 4, MTLResourceOptions::StorageModePrivate);
let mut dev_idxs_1 =
device.new_buffer(el_count as u64 * 4, MTLResourceOptions::StorageModePrivate);
let mut block_partitions = device.new_buffer(
(nrows * (nblocks + 1)) as u64 * 4,
MTLResourceOptions::StorageModePrivate,
);
// Prepare command encoder
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
// Do blockwise sort
{
let name = format!("sort_mbsort_{dtype_str}_uint32_bn{bn}_tn{tn}");
let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?;
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
&src,
&mut dev_vals_0,
&mut dev_idxs_0,
/* size_sorted_axis */ ncols as i32,
/* stride_sorted_axis */ 1i32,
/* nc_dim */ 1i32,
/* nc_shape */ nrows as i32,
/* nc_str */ ncols as i32
)
);
let thread_group_count = MTLSize {
width: nblocks as u64,
height: nrows as u64,
depth: 1,
};
let thread_group_size = MTLSize {
width: bn as u64,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
}
// Do merges
let mut ping = false;
let mut merge_tiles = 2;
let n_thr_per_group = usize::min(nblocks + 1, 1024);
let partition_name = format!("partition_mbsort_{dtype_str}_uint32_bn{bn}_tn{tn}");
let merge_name = format!("merge_mbsort_float32_uint32_bn{bn}_tn{tn}");
while merge_tiles / 2 < nblocks {
let (dev_vals_in, dev_vals_out) = if ping {
(&mut dev_vals_1, &mut dev_vals_0)
} else {
(&mut dev_vals_0, &mut dev_vals_1)
};
let (dev_idxs_in, dev_idxs_out) = if ping {
(&mut dev_idxs_1, &mut dev_idxs_0)
} else {
(&mut dev_idxs_0, &mut dev_idxs_1)
};
ping = !ping;
// Do partition
{
let pipeline =
kernels.load_pipeline(device, Source::MlxSort, partition_name.clone())?;
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
&mut block_partitions,
&mut *dev_vals_in,
&mut *dev_idxs_in,
/* size_sorted_axis */ ncols as i32,
/* merge_tiles */ merge_tiles as i32,
/* n_blocks */ nblocks as i32
)
);
let thread_group_count = MTLSize {
width: 1,
height: nrows as u64,
depth: 1,
};
let thread_group_size = MTLSize {
width: n_thr_per_group as u64,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
}
// Do merge
{
let pipeline = kernels.load_pipeline(device, Source::MlxSort, merge_name.clone())?;
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
&block_partitions,
&*dev_vals_in,
&*dev_idxs_in,
&*dev_vals_out,
&*dev_idxs_out,
/* size_sorted_axis */ ncols as i32,
/* merge_tiles */ merge_tiles as i32,
/* n_blocks */ nblocks as i32
)
);
let thread_group_count = MTLSize {
width: nblocks as u64,
height: nrows as u64,
depth: 1,
};
let thread_group_size = MTLSize {
width: bn as u64,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
}
merge_tiles *= 2;
}
let dev_idxs_out = if ping {
&mut dev_idxs_1
} else {
&mut dev_idxs_0
};
// Copy output with appropriate strides
let copy_kernel = match dtype {
DType::FP8(_) => todo!(),
DType::U8 => crate::copy2d::U8,
DType::U32 => crate::copy2d::U32,
DType::I64 => crate::copy2d::I64,
DType::BF16 => crate::copy2d::BFLOAT,
DType::F16 => crate::copy2d::HALF,
DType::F32 => crate::copy2d::FLOAT,
};
crate::call_copy2d(
device,
encoder,
kernels,
copy_kernel,
dev_idxs_out,
dst,
/* d1 */ nrows,
/* d2 */ ncols,
/* src_s */ ncols,
/* dst_s */ ncols,
/* src_o_in_bytes */ 0,
/*dst_o_in_bytes */ 0,
)?;
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn block_sort(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
dtype: DType,
bn: usize,
tn: usize,
nrows: usize,
ncols: usize,
src: BufferOffset,
dst: &Buffer,
) -> Result<(), MetalKernelError> {
let dtype_str = mlx_dtype_str(dtype);
let name = format!("carg_block_sort_{dtype_str}_uint32_bn{bn}_tn{tn}");
let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
&src,
dst,
ncols as i32,
1i32,
1i32,
ncols as i32,
ncols as i32
)
);
let thread_group_count = MTLSize {
width: 1,
height: nrows as u64,
depth: 1,
};
let thread_group_size = MTLSize {
width: bn as u64,
height: 1,
depth: 1,
};
encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_mlx_arg_sort(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
dtype: DType,
nrows: usize,
ncols: usize,
src: BufferOffset,
dst: &Buffer,
) -> Result<(), MetalKernelError> {
let tn = 8;
let bn = match ncols.div_ceil(tn) {
257.. if dtype.size_in_bytes() <= 4 => 512,
129.. => 256,
0..129 => 128,
};
let n_per_block = bn * tn;
let n_blocks = ncols.div_ceil(n_per_block);
if n_blocks > 1 {
multi_block_sort(
device, ep, kernels, dtype, bn, tn, n_blocks, nrows, ncols, src, dst,
)?
} else {
block_sort(device, ep, kernels, dtype, bn, tn, nrows, ncols, src, dst)?
}
Ok(())
}