mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Metal: Improved reduce and softmax (#1819)
* Improve reduce perf and add contiguous impl
* Improve arg reduce and add contiguous impl
* Improve softmax kernel. 33%-39% higher thrpt
* fmt
* Fixed all bugs. Improved code quality. Added tests.
* Stash for debugging
* Stash for debugging 2
* Fixing argmax bug and improve performance

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>

* Fix test and add is_valid_simgroup_reduce_type trait
* Online softmax. Improved threadgroup reduce. Tidying up a bit.
* Remove redundant threadgroup_barrier from arg reduce
* Mostly tidying up. Some improvements
* Simplify indexed struct
* tidying
* Reuse operation operator instead of passing it in as a parameter
* Fix how operators are applied to indexed<vec<T,N>>
* Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce.
* Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal.
* Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math
* Use constant for input instead of const device. Fix strided reduce.
* Use contiguous reduce in tests
* Rename finalize -> to_scalar
* Support integer types max/min (switch with trait-inferred impl later)
* Was worried I was skipping work -> shuffling the 1D test cases
* Add build.rs to avoid metal kernel jit compile overhead
* Improve build. Extract utils
* Compile metal kernels for both macos and ios
* Fixed over xmas and then forgot about it
* Add calculate_reduce_threads util
* Remove old reduce.metal
* Improve f16/bf16 softmax precision by accumulating in f32
* Remove build.rs (for now)
* Move softmax bench to candle-nn
* Remove redundant thread calc util fn
* Use uint over ushort for indices etc
* Use fast exp in MDReduceOp
* Remove nested metal define for softmax
* Fix some clippy lint.

---------

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -5,14 +5,12 @@ use metal::{
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::sync::RwLock;
|
||||
|
||||
pub mod mlx_gemm;
|
||||
pub mod sort;
|
||||
pub mod utils;
|
||||
pub use utils::BufferOffset;
|
||||
|
||||
pub use mlx_gemm::{call_mlx_gemm, GemmDType};
|
||||
pub use sort::{call_arg_sort, call_mlx_arg_sort};
|
||||
pub use utils::BufferOffset;
|
||||
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
|
||||
|
||||
const AFFINE: &str = include_str!("affine.metal");
|
||||
@ -176,7 +174,7 @@ pub enum MetalKernelError {
|
||||
LockError(String),
|
||||
#[error("Error while loading library: {0}")]
|
||||
LoadLibraryError(String),
|
||||
#[error("Error while loading function: {0:?}")]
|
||||
#[error("Error while loading function: {0}")]
|
||||
LoadFunctionError(String),
|
||||
#[error("Failed to create compute function")]
|
||||
FailedToCreateComputeFunction,
|
||||
@ -635,19 +633,31 @@ pub fn call_reduce_contiguous(
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
kernel_name: &'static str,
|
||||
length: usize,
|
||||
shape: &[usize],
|
||||
out_length: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let length = shape.iter().product::<usize>();
|
||||
let num_dims = shape.len();
|
||||
let work_per_threadgroup = length / out_length;
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let elements_to_sum = length / out_length;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, elements_to_sum, &input, output));
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
length,
|
||||
num_dims,
|
||||
shape,
|
||||
work_per_threadgroup,
|
||||
&input,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: out_length as u64,
|
||||
@ -657,9 +667,8 @@ pub fn call_reduce_contiguous(
|
||||
|
||||
let width = std::cmp::min(
|
||||
pipeline.max_total_threads_per_threadgroup(),
|
||||
(elements_to_sum as u64).div_ceil(2),
|
||||
)
|
||||
.next_power_of_two();
|
||||
(work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
|
||||
);
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
@ -686,8 +695,9 @@ pub fn call_reduce_strided(
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let length: usize = shape.iter().product();
|
||||
let num_dims = shape.len();
|
||||
let work_per_threadgroup = length / out_length;
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let elements_to_sum = length / out_length;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
@ -695,7 +705,15 @@ pub fn call_reduce_strided(
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(shape.len(), shape, strides, elements_to_sum, &input, output)
|
||||
(
|
||||
length,
|
||||
num_dims,
|
||||
shape,
|
||||
strides,
|
||||
work_per_threadgroup,
|
||||
&input,
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
@ -706,16 +724,14 @@ pub fn call_reduce_strided(
|
||||
|
||||
let width = std::cmp::min(
|
||||
pipeline.max_total_threads_per_threadgroup(),
|
||||
elements_to_sum as u64,
|
||||
)
|
||||
.next_power_of_two();
|
||||
(work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
|
||||
);
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
@ -729,11 +745,13 @@ pub fn call_last_softmax(
|
||||
kernels: &Kernels,
|
||||
kernel_name: &'static str,
|
||||
length: usize,
|
||||
elements_to_sum: usize,
|
||||
elements: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let work_per_threadgroup = elements;
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
@ -741,29 +759,27 @@ pub fn call_last_softmax(
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(length, elements_to_sum, (input, input_offset), output)
|
||||
(length, work_per_threadgroup, (input, input_offset), output)
|
||||
);
|
||||
|
||||
let out_length = length / elements_to_sum;
|
||||
let out_length = length / work_per_threadgroup;
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: out_length as u64,
|
||||
width: out_length as NSUInteger,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let width = std::cmp::min(
|
||||
pipeline.max_total_threads_per_threadgroup(),
|
||||
elements_to_sum as u64,
|
||||
)
|
||||
.next_power_of_two();
|
||||
(work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
|
||||
);
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
|
Reference in New Issue
Block a user