Minor cleanups in reduce.metal. (#2004)

This commit is contained in:
Laurent Mazare
2024-04-04 08:26:02 +02:00
committed by GitHub
parent bd8db2a771
commit 1e46cf8b19

View File

@@ -37,17 +37,13 @@ METAL_FUNC void argmin(
threadgroup uint *shared_indices threadgroup uint *shared_indices
) { ) {
bool notset = true; bool notset = true;
/*
// Elements summed in this block range from dst_id * el_to_sum_per_block // Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block. // to (dst_id + 1) * el_to_sum_per_block.
*/
size_t start_idx = dst_id * el_to_sum_per_block; size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block; size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid; size_t idx = start_idx + tid;
while (idx < stop_idx) { while (idx < stop_idx) {
/*
// TODO: Fast version for the contiguous case. // TODO: Fast version for the contiguous case.
*/
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (notset || src[strided_i] < shared_memory[tid]) { if (notset || src[strided_i] < shared_memory[tid]) {
shared_memory[tid] = src[strided_i]; shared_memory[tid] = src[strided_i];
@@ -59,9 +55,7 @@ METAL_FUNC void argmin(
} }
threadgroup_barrier(mem_flags::mem_none); threadgroup_barrier(mem_flags::mem_none);
/*
// reduction in shared memory // reduction in shared memory
*/
for (uint s = block_dim / 2; s > 0; s >>= 1) { for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
shared_indices[tid] = shared_indices[tid + s]; shared_indices[tid] = shared_indices[tid + s];
@@ -69,8 +63,7 @@ METAL_FUNC void argmin(
} \ } \
threadgroup_barrier(mem_flags::mem_none); threadgroup_barrier(mem_flags::mem_none);
} }
if (tid == 0) {
if (tid == 0){
dst[dst_id] = shared_indices[0]; dst[dst_id] = shared_indices[0];
} }
} }
@@ -111,18 +104,14 @@ METAL_FUNC void argmax(
threadgroup T * shared_memory, threadgroup T * shared_memory,
threadgroup uint * shared_indices threadgroup uint * shared_indices
) { ) {
/*
// Elements summed in this block range from dst_id * el_to_sum_per_block // Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block. // to (dst_id + 1) * el_to_sum_per_block.
*/
size_t start_idx = dst_id * el_to_sum_per_block; size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block; size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid; size_t idx = start_idx + tid;
bool notset = true; bool notset = true;
while (idx < stop_idx) { while (idx < stop_idx) {
/*
// TODO: Fast version for the contiguous case. // TODO: Fast version for the contiguous case.
*/
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (notset || shared_memory[tid] < src[strided_i]) { if (notset || shared_memory[tid] < src[strided_i]) {
shared_memory[tid] = src[strided_i]; shared_memory[tid] = src[strided_i];
@@ -134,9 +123,7 @@ METAL_FUNC void argmax(
threadgroup_barrier(mem_flags::mem_none); threadgroup_barrier(mem_flags::mem_none);
/*
// reduction in shared memory // reduction in shared memory
*/
for (uint s = block_dim / 2; s > 0; s >>= 1) { for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
shared_indices[tid] = shared_indices[tid + s]; shared_indices[tid] = shared_indices[tid + s];
@@ -145,9 +132,7 @@ METAL_FUNC void argmax(
threadgroup_barrier(mem_flags::mem_none); threadgroup_barrier(mem_flags::mem_none);
} }
/*
// Thread 0 writes the result of the reduction // Thread 0 writes the result of the reduction
*/
if (tid == 0) { if (tid == 0) {
dst[dst_id] = shared_indices[0]; dst[dst_id] = shared_indices[0];
} }
@@ -188,17 +173,13 @@ METAL_FUNC void reduce(
threadgroup T * shared_memory, threadgroup T * shared_memory,
T (*fn)(T, T) T (*fn)(T, T)
) { ) {
/*
// Elements summed in this block range from dst_id * el_to_sum_per_block // Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block. // to (dst_id + 1) * el_to_sum_per_block.
*/
size_t start_idx = dst_id * el_to_sum_per_block; size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block; size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid; size_t idx = start_idx + tid;
while (idx < stop_idx) { while (idx < stop_idx) {
/*
// TODO: Fast version for the contiguous case. // TODO: Fast version for the contiguous case.
*/
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
T x = shared_memory[tid]; T x = shared_memory[tid];
T y = src[strided_i]; T y = src[strided_i];
@@ -208,9 +189,7 @@ METAL_FUNC void reduce(
threadgroup_barrier(mem_flags::mem_none); threadgroup_barrier(mem_flags::mem_none);
/*
// reduction in shared memory // reduction in shared memory
*/
for (uint s = block_dim / 2; s > 0; s >>= 1) { for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) { if (tid < s) {
T x = shared_memory[tid]; T x = shared_memory[tid];
@@ -277,7 +256,6 @@ METAL_FUNC void softmax(
} }
/* wait for shared_memory[0] to be filled */ /* wait for shared_memory[0] to be filled */
\
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
float _max = shared_memory[0]; float _max = shared_memory[0];