Finish reduce kernels.

This commit is contained in:
Nicolas Patry
2023-12-17 19:07:00 +01:00
parent 6bc92e63cb
commit 972903021c
6 changed files with 258 additions and 39 deletions

View File

@ -1,5 +1,8 @@
#include <metal_stdlib>
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
@ -63,10 +66,14 @@ BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)
BINARY_OP(MIN(x, y), min)
BINARY_OP(MAX(x, y), max)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
BFLOAT_BINARY_OP(MIN(x, y), min)
BFLOAT_BINARY_OP(MAX(x, y), max)
#endif

View File

@ -166,7 +166,7 @@ pub mod unary {
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
}
pub mod binary {
ops!(add, sub, mul, div);
ops!(add, sub, mul, div, min, max);
}
#[derive(thiserror::Error, Debug)]
@ -588,6 +588,64 @@ pub fn call_reduce_contiguous(
Ok(())
}
pub fn call_reduce_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
shape: &[usize],
strides: &[usize],
out_length: usize,
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let length: usize = shape.iter().product();
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
shape.len(),
shape,
strides,
elements_to_sum,
(input, input_offset),
output
)
);
let thread_group_count = MTLSize {
width: out_length as u64,
height: 1,
depth: 1,
};
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
elements_to_sum as u64,
)
.next_power_of_two();
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_last_softmax(
device: &Device,

View File

@ -2,6 +2,7 @@
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
uint idx,
@ -20,9 +21,130 @@ METAL_FUNC uint get_strided_index(
constant int THREADGROUP_SIZE = 2048;
# define REDUCE(FN, NAME, T) \
#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MAXVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
bool notset = true; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || src[strided_i] < shared_memory[tid]) { \
shared_memory[tid] = src[strided_i]; \
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MINVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
bool notset = true; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || shared_memory[tid] < src[strided_i]) { \
shared_memory[tid] = src[strided_i]; \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define REDUCE(FN, NAME, T, START) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
@ -34,21 +156,21 @@ kernel void NAME( \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = 0; \
shared_memory[tid] = START; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
T x = shared_memory[tid]; \
T y = src[idx]; \
T y = src[strided_i]; \
shared_memory[tid] = FN; \
idx += block_dim; \
} \
@ -71,10 +193,6 @@ kernel void NAME( \
} \
REDUCE(x + y, fast_sum_f32, float)
REDUCE(x * y, fast_mul_f32, float)
REDUCE(max(x, y), fast_max_f32, float)
#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
@ -142,8 +260,33 @@ kernel void NAME(
} \
} \
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
REDUCE(x * y, fast_mul_f32_strided, float, 1)
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
REDUCE(x * y, fast_mul_f16_strided, half, 1)
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
#if __METAL_VERSION__ >= 310
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
#endif

View File

@ -574,12 +574,16 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
call_reduce_contiguous(
let num_dims = 1;
let dims = vec![v.len()];
let strides = vec![1];
call_reduce_strided(
&device,
command_buffer,
&kernels,
name,
v.len(),
&dims,
&strides,
out_length,
&input,
0,
@ -623,7 +627,7 @@ fn reduce_sum() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 1;
let results = run_reduce(&v, out_length, "fast_sum_f32");
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![21.0]);
}
@ -632,7 +636,7 @@ fn reduce_sum2() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 2;
let results = run_reduce(&v, out_length, "fast_sum_f32");
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}