Fixing the kernels + launches to make them faster.

Cool work by @ivarflakstad

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Nicolas Patry
2023-11-10 11:14:51 +01:00
parent 02c2ec2c71
commit cc26cce23c
6 changed files with 69 additions and 162 deletions
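Every kernel here gets the same treatment: instead of a single threadgroup whose threads each loop over a start..stop slice of the buffer, every thread now handles exactly one element, and the host dispatches enough threadgroups to cover the whole length. The arithmetic lives in the new linear_split helper in lib.rs below; a minimal standalone sketch of it in plain Rust (no metal crate; the hypothetical max_threads stands in for pipeline.max_total_threads_per_threadgroup()):

// Standalone sketch of the launch arithmetic performed by linear_split.
// `max_threads` is a stand-in for pipeline.max_total_threads_per_threadgroup().
fn launch_geometry(length: u64, max_threads: u64) -> (u64, u64) {
    // Threadgroup width: as wide as the pipeline allows, but never wider
    // than the problem itself.
    let width = max_threads.min(length);
    // Ceil-divide so width * count covers every element at least once.
    let count = (length + width - 1) / width;
    (count, width)
}

fn main() {
    // 10_000 elements on a pipeline that allows 1024 threads per group:
    // 10 groups x 1024 threads = 10_240 threads dispatched.
    let (count, width) = launch_geometry(10_000, 1024);
    assert_eq!((count, width), (10, 1024));
    // The 240 surplus threads in the last group are why every rewritten
    // kernel now starts with a bounds check like `if (id >= dim) return;`.
    println!("{count} groups of {width} threads");
}

Since dispatch_thread_groups always launches whole threadgroups, the last group generally overshoots the buffer; the per-kernel bounds check is what keeps those surplus threads from writing out of range.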

affine.metal

@@ -24,17 +24,14 @@ kernel void FN_NAME( \
     constant float &add, \
     device const TYPENAME *input, \
     device TYPENAME *output, \
-    uint threadgroup_size [[threads_per_threadgroup]], \
-    uint thread_index [[thread_index_in_threadgroup]] \
+    uint id [[ thread_position_in_grid ]] \
 ) { \
+    if (id >= dim) { \
+        return; \
+    } \
     const TYPENAME m = TYPENAME(mul); \
     const TYPENAME a = TYPENAME(add); \
-    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
-    const size_t start = thread_index * length; \
-    const size_t stop = min(start + length, dim); \
-    for (size_t i = start; i < stop; i++){ \
-        output[i] = input[i] * m + a; \
-    } \
+    output[id] = input[id] * m + a; \
 } \
 
 AFFINE(affine_float, float)

binary.metal

@@ -23,17 +23,14 @@ kernel void FN_NAME( \
     device const TYPENAME *left, \
     device const TYPENAME *right, \
     device TYPENAME *output, \
-    uint threadgroup_size [[threads_per_threadgroup]], \
-    uint thread_index [[thread_index_in_threadgroup]] \
+    uint thread_position_in_grid [[ thread_position_in_grid ]] \
 ) { \
-    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
-    const size_t start = thread_index * length; \
-    const size_t stop = min(start + length, dim); \
-    for (size_t i = start; i < stop; i++){ \
-        TYPENAME x = left[i]; \
-        TYPENAME y = right[i]; \
-        output[i] = OUT_TYPENAME(FN); \
-    } \
+    if (thread_position_in_grid >= dim) { \
+        return; \
+    } \
+    TYPENAME x = left[thread_position_in_grid]; \
+    TYPENAME y = right[thread_position_in_grid]; \
+    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
 }\
 kernel void FN_NAME_STRIDED( \
     constant size_t &dim, \
@@ -44,17 +41,14 @@ kernel void FN_NAME_STRIDED( \
     device const TYPENAME *left, \
     device const TYPENAME *right, \
     device TYPENAME *output, \
-    uint threadgroup_size [[threads_per_threadgroup]], \
-    uint thread_index [[thread_index_in_threadgroup]] \
+    uint thread_position_in_grid [[ thread_position_in_grid ]] \
 ) { \
-    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
-    const size_t start = thread_index * length; \
-    const size_t stop = min(start + length, dim); \
-    for (size_t i = start; i < stop; i++){ \
-        TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \
-        TYPENAME y = left[get_strided_index(i, num_dims, dims, right_strides)]; \
-        output[i] = OUT_TYPENAME(FN); \
-    } \
+    if (thread_position_in_grid >= dim) { \
+        return; \
+    } \
+    TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
+    TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
+    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
 }
 
 #define BINARY_OP(FN, NAME) \
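The strided variants map a flat grid position to an element offset through get_strided_index, which is defined in the shared kernel utilities and not part of this diff. As an illustration, a Rust rendering of the conventional row-major decomposition such a helper performs (not the Metal source):

// Illustration only: peel the flat index into per-dimension coordinates,
// innermost dimension first, and dot them with the strides.
fn strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    for d in (0..dims.len()).rev() {
        offset += (idx % dims[d]) * strides[d];
        idx /= dims[d];
    }
    offset
}

fn main() {
    // A 2x3 tensor stored transposed: logical dims [2, 3], strides [1, 2].
    // Flat index 4 = row 1, col 1 -> offset 1*1 + 1*2 = 3.
    assert_eq!(strided_index(4, &[2, 3], &[1, 2]), 3);
}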

cast.metal

@@ -23,15 +23,12 @@ kernel void FN_NAME( \
     constant size_t &dim, \
     device const LEFT_TYPENAME *input, \
     device RIGHT_TYPENAME *output, \
-    uint threadgroup_size [[threads_per_threadgroup]], \
-    uint thread_index [[thread_index_in_threadgroup]] \
+    uint thread_position_in_grid [[ thread_position_in_grid ]] \
 ) { \
-    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
-    const size_t start = thread_index * length; \
-    const size_t stop = min(start + length, dim); \
-    for (size_t i = start; i < stop; i++){ \
-        output[i] = RIGHT_TYPENAME(input[i]); \
-    } \
+    if (thread_position_in_grid >= dim) { \
+        return; \
+    } \
+    output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
 } \
 kernel void FN_NAME_STRIDED( \
     constant size_t &dim, \
@@ -40,17 +37,13 @@ kernel void FN_NAME_STRIDED( \
     constant size_t *strides, \
     device const LEFT_TYPENAME *input, \
     device RIGHT_TYPENAME *output, \
-    uint threadgroup_size [[threads_per_threadgroup]], \
-    uint thread_index [[thread_index_in_threadgroup]] \
+    uint i [[ thread_position_in_grid ]] \
 ) { \
-    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
-    const size_t start = thread_index * length; \
-    const size_t stop = min(start + length, dim); \
-    for (size_t i = start; i < stop; i++){ \
-        output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
-    } \
-}
+    if (i >= dim) { \
+        return; \
+    } \
+    output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
+} \
 
 CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)

indexing.metal

@@ -2,7 +2,7 @@
 using namespace metal;
 
 # define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
 kernel void NAME( \
     constant size_t &dst_size, \
     constant size_t &left_size, \
     constant size_t &src_dim_size, \
@@ -42,12 +42,9 @@ void index_add(
     constant uint &ids_dim_size,
     constant uint &left_size,
     constant uint &dst_dim_size,
     constant uint &right_size,
-    uint threadgroup_size [[threads_per_threadgroup]],
-    uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
-    uint thread_index [[thread_index_in_threadgroup]]
+    uint gid [[ thread_position_in_grid ]]
 ) {
-    const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
     if (gid >= left_size * right_size) {
         return;
     }
@@ -73,14 +70,13 @@ kernel void FN_NAME( \
     constant uint &left_size, \
     constant uint &dst_dim_size, \
     constant uint &right_size, \
-    uint threadgroup_size [[threads_per_threadgroup]], \
-    uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
-    uint thread_index [[thread_index_in_threadgroup]] \
-) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
+    uint gid [[ thread_position_in_grid ]] \
+) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
 INDEX_OP(is_u32_f32, uint, float)
 
 #if __METAL_VERSION__ >= 310
 IA_OP(bfloat, int64_t, ia_i64_bf16)
 IA_OP(bfloat, uint32_t, ia_u32_bf16)
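The body of index_add sits outside these hunks, but the signature and the gid guard pin down the threading scheme: left_size * right_size threads, one per (outer, inner) slot. A hypothetical CPU mirror of that scatter-add, inferred from the parameter names rather than copied from the kernel:

// Hypothetical reference implementation; one GPU thread corresponds to
// one (pre, post) pair below.
fn index_add_ref(
    ids: &[usize],   // ids_dim_size entries, each a row index < dst_dim_size
    inp: &[f32],     // left_size * ids_dim_size * right_size elements
    out: &mut [f32], // left_size * dst_dim_size * right_size elements
    dst_dim_size: usize,
    left_size: usize,
    right_size: usize,
) {
    let ids_dim_size = ids.len();
    for pre in 0..left_size {
        for post in 0..right_size {
            // Everything the thread with gid = pre * right_size + post touches.
            for (j, &dst_row) in ids.iter().enumerate() {
                let src = (pre * ids_dim_size + j) * right_size + post;
                let dst = (pre * dst_dim_size + dst_row) * right_size + post;
                out[dst] += inp[src];
            }
        }
    }
}

fn main() {
    // Add both rows of a 2x2 source into row 1 of a 3x2 destination.
    let ids = [1usize, 1];
    let inp = [1.0f32, 2.0, 3.0, 4.0];
    let mut out = [0.0f32; 6];
    index_add_ref(&ids, &inp, &mut out, 3, 1, 2);
    assert_eq!(out, [0.0, 0.0, 4.0, 6.0, 0.0, 0.0]);
}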

lib.rs

@@ -1,7 +1,7 @@
 #![allow(clippy::too_many_arguments)]
 use metal::{
     Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
-    Device, Function, Library, MTLSize,
+    ComputePipelineState, Device, Function, Library, MTLSize,
 };
 use std::collections::HashMap;
 use std::ffi::c_void;
@@ -15,6 +15,24 @@ const TERNARY: &str = include_str!("ternary.metal");
 const CAST: &str = include_str!("cast.metal");
 const REDUCE: &str = include_str!("reduce.metal");
 
+fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
+    let size = length as u64;
+    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
+    let count = (size + width - 1) / width;
+    let thread_group_count = MTLSize {
+        width: count,
+        height: 1,
+        depth: 1,
+    };
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+    (thread_group_count, thread_group_size)
+}
+
 fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
     <P as EncoderParam>::set_param(encoder, position, data)
 }
@@ -257,19 +275,7 @@ pub fn call_unary_contiguous(
     set_params!(encoder, (length, input, output));
 
-    let thread_group_count = MTLSize {
-        width: 1,
-        height: 1,
-        depth: 1,
-    };
-
-    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
-    let thread_group_size = MTLSize {
-        width,
-        height: 1,
-        depth: 1,
-    };
-
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     encoder.end_encoding();
     Ok(())
@@ -314,17 +320,7 @@ pub fn call_unary_strided(
     );
     let width: usize = shape.iter().product();
 
-    let thread_group_count = MTLSize {
-        width: 1,
-        height: 1,
-        depth: 1,
-    };
-
-    let thread_group_size = MTLSize {
-        width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64),
-        height: 1,
-        depth: 1,
-    };
-
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     encoder.end_encoding();
@@ -358,18 +354,7 @@ pub fn call_binary_contiguous(
     set_params!(encoder, (length, left, right, output));
 
-    let thread_group_count = MTLSize {
-        width: 1,
-        height: 1,
-        depth: 1,
-    };
-
-    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
-    let thread_group_size = MTLSize {
-        width,
-        height: 1,
-        depth: 1,
-    };
-
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     encoder.end_encoding();
@@ -421,17 +406,7 @@ pub fn call_binary_strided(
         )
     );
 
-    let thread_group_count = MTLSize {
-        width: 1,
-        height: 1,
-        depth: 1,
-    };
-
-    let thread_group_size = MTLSize {
-        width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64),
-        height: 1,
-        depth: 1,
-    };
-
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     encoder.end_encoding();
@@ -464,18 +439,7 @@ pub fn call_cast_contiguous(
     set_params!(encoder, (length, input, output));
 
-    let thread_group_count = MTLSize {
-        width: 1,
-        height: 1,
-        depth: 1,
-    };
-
-    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
-    let thread_group_size = MTLSize {
-        width,
-        height: 1,
-        depth: 1,
-    };
-
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     encoder.end_encoding();
@@ -608,19 +572,7 @@ pub fn call_affine(
     set_params!(encoder, (size, mul, add, input, output));
 
-    let thread_group_count = MTLSize {
-        width: 1,
-        height: 1,
-        depth: 1,
-    };
-
-    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
-    let thread_group_size = MTLSize {
-        width,
-        height: 1,
-        depth: 1,
-    };
-
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     encoder.end_encoding();
     Ok(())
@@ -672,18 +624,7 @@ pub fn call_where_cond_strided(
         )
     );
 
-    let thread_group_count = MTLSize {
-        width: 1,
-        height: 1,
-        depth: 1,
-    };
-
-    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
-    let thread_group_size = MTLSize {
-        width,
-        height: 1,
-        depth: 1,
-    };
-
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     encoder.end_encoding();
@@ -730,19 +671,9 @@ pub fn call_index_select(
         )
     );
 
-    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
-    let grid_size = MTLSize {
-        width: (dst_el as u64 + width - 1) / width,
-        height: 1,
-        depth: 1,
-    };
-
-    let thread_group_size = MTLSize {
-        width,
-        height: 1,
-        depth: 1,
-    };
-
-    encoder.dispatch_thread_groups(grid_size, thread_group_size);
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     encoder.end_encoding();
 
     Ok(())
 }

unary.metal

@@ -27,15 +27,12 @@ kernel void FN_NAME( \
     constant size_t &dim, \
     device const TYPENAME *input, \
     device TYPENAME *output, \
-    uint threadgroup_size [[threads_per_threadgroup]], \
-    uint thread_index [[thread_index_in_threadgroup]] \
+    uint thread_position_in_grid [[ thread_position_in_grid ]] \
 ) { \
-    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
-    const size_t start = thread_index * length; \
-    const size_t stop = min(start + length, dim); \
-    for (size_t i = start; i < stop; i++){ \
-        output[i] = TYPENAME(FN(input[i])); \
-    } \
+    if (thread_position_in_grid >= dim) { \
+        return; \
+    } \
+    output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
 }\
 kernel void FN_NAME_STRIDED( \
     constant size_t &dim, \
@@ -44,15 +41,12 @@ kernel void FN_NAME_STRIDED( \
     constant size_t *strides, \
     device const TYPENAME *input, \
     device TYPENAME *output, \
-    uint threadgroup_size [[threads_per_threadgroup]], \
-    uint thread_index [[thread_index_in_threadgroup]] \
+    uint thread_position_in_grid [[ thread_position_in_grid ]] \
 ) { \
-    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
-    const size_t start = thread_index * length; \
-    const size_t stop = min(start + length, dim); \
-    for (size_t i = start; i < stop; i++){ \
-        output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
-    } \
+    if (thread_position_in_grid >= dim) { \
+        return; \
+    } \
+    output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
 }
 
 #define UNARY_OP(NAME) \
@@ -79,4 +73,6 @@ BFLOAT_UNARY_OP(sqr)
 BFLOAT_UNARY_OP(sqrt)
 BFLOAT_UNARY_OP(neg)
 BFLOAT_UNARY_OP(exp)
+
+UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
 #endif