Fixing the kernels + launches to make them faster.

Cool work by @ivarflakstad

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2023-11-10 11:14:51 +01:00
parent 02c2ec2c71
commit cc26cce23c
6 changed files with 69 additions and 162 deletions

View File

@ -1,7 +1,7 @@
#![allow(clippy::too_many_arguments)]
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
Device, Function, Library, MTLSize,
ComputePipelineState, Device, Function, Library, MTLSize,
};
use std::collections::HashMap;
use std::ffi::c_void;
@ -15,6 +15,24 @@ const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
let count = (size + width - 1) / width;
let thread_group_count = MTLSize {
width: count,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
(thread_group_count, thread_group_size)
}
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
@ -257,19 +275,7 @@ pub fn call_unary_contiguous(
set_params!(encoder, (length, input, output));
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
@ -314,17 +320,7 @@ pub fn call_unary_strided(
);
let width: usize = shape.iter().product();
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64),
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -358,18 +354,7 @@ pub fn call_binary_contiguous(
set_params!(encoder, (length, left, right, output));
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -421,17 +406,7 @@ pub fn call_binary_strided(
)
);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64),
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -464,18 +439,7 @@ pub fn call_cast_contiguous(
set_params!(encoder, (length, input, output));
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -608,19 +572,7 @@ pub fn call_affine(
set_params!(encoder, (size, mul, add, input, output));
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
@ -672,18 +624,7 @@ pub fn call_where_cond_strided(
)
);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -730,19 +671,9 @@ pub fn call_index_select(
)
);
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
let grid_size = MTLSize {
width: (dst_el as u64 + width - 1) / width,
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}