Metal: Improved reduce and softmax (#1819)

* Improve reduce perf and add contiguous impl

* Improve arg reduce and add contiguous impl

* Improve softmax kernel. 33%-39% higher thrpt

* fmt

* Fixed all bugs. Improved code quality. Added tests.

* Stash for debugging

* Stash for debugging 2

* Fixing argmax bug and improve performance

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>

* Fix test and add is_valid_simgroup_reduce_type trait

* Online softmax. Improved threadgroup reduce. Tidying up a bit.

* Remove redundant threadgroup_barrier from arg reduce

* Mostly tidying up. Some improvements

* Simplify indexed struct

* tidying

* Reuse operation operator instead of passing it in as a parameter

* Fix how operators are applied to indexed<vec<T,N>>

* Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce.

* Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal.

* Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math

* Use constant for input instead of const device. Fix strided reduce.

* Use contiguous reduce in tests

* Rename finalize -> to_scalar

* Support integer types max/min (switch with trait-inferred impl later)

* Was worried I was skipping work -> shuffling the 1D test cases

* Add build.rs to avoid metal kernel jit compile overhead

* Improve build. Extract utils

* Compile metal kernels for both macos and ios

* Fixed over xmas and then forgot about it

* Add calculate_reduce_threads util

* Remove old reduce.metal

* Improve f16/bf16 softmax precision by accumulating in f32

* Remove build.rs (for now)

* Move softmax bench to candle-nn

* Remove redundant thread calc util fn

* Use uint over ushort for indices etc

* Use fast exp in MDReduceOp

* Remove nested metal define for softmax

* Fix some clippy lint.

---------

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
ivarflakstad
2025-02-08 07:27:01 +01:00
committed by GitHub
parent 0af3e428ec
commit 7c2449f623
12 changed files with 1521 additions and 357 deletions

View File

@ -5,14 +5,12 @@ use metal::{
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
pub mod mlx_gemm;
pub mod sort;
pub mod utils;
pub use utils::BufferOffset;
pub use mlx_gemm::{call_mlx_gemm, GemmDType};
pub use sort::{call_arg_sort, call_mlx_arg_sort};
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
@ -176,7 +174,7 @@ pub enum MetalKernelError {
LockError(String),
#[error("Error while loading library: {0}")]
LoadLibraryError(String),
#[error("Error while loading function: {0:?}")]
#[error("Error while loading function: {0}")]
LoadFunctionError(String),
#[error("Failed to create compute function")]
FailedToCreateComputeFunction,
@ -635,19 +633,31 @@ pub fn call_reduce_contiguous(
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
shape: &[usize],
out_length: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let length = shape.iter().product::<usize>();
let num_dims = shape.len();
let work_per_threadgroup = length / out_length;
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, elements_to_sum, &input, output));
set_params!(
encoder,
(
length,
num_dims,
shape,
work_per_threadgroup,
&input,
output
)
);
let thread_group_count = MTLSize {
width: out_length as u64,
@ -657,9 +667,8 @@ pub fn call_reduce_contiguous(
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
(elements_to_sum as u64).div_ceil(2),
)
.next_power_of_two();
(work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
);
let thread_group_size = MTLSize {
width,
@ -686,8 +695,9 @@ pub fn call_reduce_strided(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let length: usize = shape.iter().product();
let num_dims = shape.len();
let work_per_threadgroup = length / out_length;
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
@ -695,7 +705,15 @@ pub fn call_reduce_strided(
set_params!(
encoder,
(shape.len(), shape, strides, elements_to_sum, &input, output)
(
length,
num_dims,
shape,
strides,
work_per_threadgroup,
&input,
output
)
);
let thread_group_count = MTLSize {
@ -706,16 +724,14 @@ pub fn call_reduce_strided(
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
elements_to_sum as u64,
)
.next_power_of_two();
(work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
@ -729,11 +745,13 @@ pub fn call_last_softmax(
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
elements_to_sum: usize,
elements: usize,
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let work_per_threadgroup = elements;
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
@ -741,29 +759,27 @@ pub fn call_last_softmax(
set_params!(
encoder,
(length, elements_to_sum, (input, input_offset), output)
(length, work_per_threadgroup, (input, input_offset), output)
);
let out_length = length / elements_to_sum;
let out_length = length / work_per_threadgroup;
let thread_group_count = MTLSize {
width: out_length as u64,
width: out_length as NSUInteger,
height: 1,
depth: 1,
};
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
elements_to_sum as u64,
)
.next_power_of_two();
(work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);