Fixing the kernels + launches to make them faster.

Cool work by @ivarflakstad

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2023-11-10 11:14:51 +01:00
parent 02c2ec2c71
commit cc26cce23c
6 changed files with 69 additions and 162 deletions

View File

@ -24,17 +24,14 @@ kernel void FN_NAME( \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = input[i] * m + a; \
} \
output[id] = input[id] * m + a; \
} \
AFFINE(affine_float, float)

View File

@ -23,17 +23,14 @@ kernel void FN_NAME( \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
TYPENAME x = left[i]; \
TYPENAME y = right[i]; \
output[i] = OUT_TYPENAME(FN); \
if (thread_position_in_grid >= dim) { \
return; \
} \
TYPENAME x = left[thread_position_in_grid]; \
TYPENAME y = right[thread_position_in_grid]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -44,17 +41,14 @@ kernel void FN_NAME_STRIDED( \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \
TYPENAME y = left[get_strided_index(i, num_dims, dims, right_strides)]; \
output[i] = OUT_TYPENAME(FN); \
if (thread_position_in_grid >= dim) { \
return; \
} \
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}
#define BINARY_OP(FN, NAME) \

View File

@ -23,15 +23,12 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = RIGHT_TYPENAME(input[i]); \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -40,17 +37,13 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
uint i [[ thread_position_in_grid ]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
if (i >= dim) { \
return; \
} \
}
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)

View File

@ -2,7 +2,7 @@
using namespace metal;
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
@ -42,12 +42,9 @@ void index_add(
constant uint &dst_dim_size,
constant uint &right_size,
uint threadgroup_size [[threads_per_threadgroup]],
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint thread_index [[thread_index_in_threadgroup]]
uint gid [[ thread_position_in_grid ]] \
) {
const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
if (gid >= left_size * right_size) {
return;
}
@ -73,14 +70,13 @@ kernel void FN_NAME( \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
uint gid [[ thread_position_in_grid ]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
INDEX_OP(is_u32_f32, uint, float)
#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)

View File

@ -1,7 +1,7 @@
#![allow(clippy::too_many_arguments)]
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
Device, Function, Library, MTLSize,
ComputePipelineState, Device, Function, Library, MTLSize,
};
use std::collections::HashMap;
use std::ffi::c_void;
@ -15,6 +15,24 @@ const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
let count = (size + width - 1) / width;
let thread_group_count = MTLSize {
width: count,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
(thread_group_count, thread_group_size)
}
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
@ -257,19 +275,7 @@ pub fn call_unary_contiguous(
set_params!(encoder, (length, input, output));
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
@ -314,17 +320,7 @@ pub fn call_unary_strided(
);
let width: usize = shape.iter().product();
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64),
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -358,18 +354,7 @@ pub fn call_binary_contiguous(
set_params!(encoder, (length, left, right, output));
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -421,17 +406,7 @@ pub fn call_binary_strided(
)
);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64),
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -464,18 +439,7 @@ pub fn call_cast_contiguous(
set_params!(encoder, (length, input, output));
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -608,19 +572,7 @@ pub fn call_affine(
set_params!(encoder, (size, mul, add, input, output));
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
@ -672,18 +624,7 @@ pub fn call_where_cond_strided(
)
);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@ -730,19 +671,9 @@ pub fn call_index_select(
)
);
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
let grid_size = MTLSize {
width: (dst_el as u64 + width - 1) / width,
height: 1,
depth: 1,
};
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}

View File

@ -27,15 +27,12 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = TYPENAME(FN(input[i])); \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -44,15 +41,12 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
}
#define UNARY_OP(NAME) \
@ -79,4 +73,6 @@ BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif