Metal: Improved reduce and softmax (#1819)

* Improve reduce perf and add contiguous impl

* Improve arg reduce and add contiguous impl

* Improve softmax kernel. 33%-39% higher thrpt

* fmt

* Fixed all bugs. Improved code quality. Added tests.

* Stash for debugging

* Stash for debugging 2

* Fixing argmax bug and improve performance

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>

* Fix test and add is_valid_simgroup_reduce_type trait

* Online softmax. Improved threadgroup reduce. Tidying up a bit.

* Remove redundant threadgroup_barrier from arg reduce

* Mostly tidying up. Some improvements

* Simplify indexed struct

* tidying

* Reuse operation operator instead of passing it in as a parameter

* Fix how operators are applied to indexed<vec<T,N>>

* Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce.

* Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal.

* Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math

* Use constant for input instead of const device. Fix strided reduce.

* Use contiguous reduce in tests

* Rename finalize -> to_scalar

* Support integer types max/min (switch with trait-inferred impl later)

* Was worried I was skipping work -> shuffling the 1D test cases

* Add build.rs to avoid metal kernel jit compile overhead

* Improve build. Extract utils

* Compile metal kernels for both macos and ios

* Fixed over xmas and then forgot about it

* Add calculate_reduce_threads util

* Remove old reduce.metal

* Improve f16/bf16 softmax precision by accumulating in f32

* Remove build.rs (for now)

* Move softmax bench to candle-nn

* Remove redundant thread calc util fn

* Use uint over ushort for indices etc

* Use fast exp in MDReduceOp

* Remove nested metal define for softmax

* Fix some clippy lint.

---------

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
ivarflakstad
2025-02-08 07:27:01 +01:00
committed by GitHub
parent 0af3e428ec
commit 7c2449f623
12 changed files with 1521 additions and 357 deletions

View File

@ -265,6 +265,7 @@ impl BackendStorage for MetalStorage {
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
// Source dims and strides with the sum dims at the end.
@ -278,13 +279,72 @@ impl BackendStorage for MetalStorage {
stride.push(src_stride[dim_idx]);
}
}
for &dim_idx in sum_dims.iter() {
dims.push(src_dims[dim_idx]);
stride.push(src_stride[dim_idx]);
}
// The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two.
let reduction_shape = Shape::from(dims.clone());
if layout.is_contiguous() && reduction_shape.is_contiguous(&stride) {
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
(k, dtype) => {
crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
}
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
src_dims,
dst_el,
src,
&buffer,
)
.map_err(MetalError::from)?;
return Ok(Self::new(buffer, device, dst_el, dtype));
}
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
@ -316,7 +376,7 @@ impl BackendStorage for MetalStorage {
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
(k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?