Minor cleanups in reduce.metal. (#2004)

Author: Laurent Mazare (committed by GitHub)
Date: 2024-04-04 08:26:02 +02:00
Parent: bd8db2a771
Commit: 1e46cf8b19


@@ -37,17 +37,13 @@ METAL_FUNC void argmin(
threadgroup uint *shared_indices
) {
bool notset = true;
/*
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
*/
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
/*
// TODO: Fast version for the contiguous case.
*/
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (notset || src[strided_i] < shared_memory[tid]) {
shared_memory[tid] = src[strided_i];
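
For context (not part of the diff): each threadgroup handles the logical element range [dst_id * el_to_sum_per_block, (dst_id + 1) * el_to_sum_per_block), and get_strided_index maps a logical index onto a possibly non-contiguous dims/strides layout. A minimal serial C++ sketch of that index mapping, restated from the kernel's usage (the helper body and the example shapes are illustrative assumptions, not copied from the repo):

    #include <cstddef>
    #include <cstdio>

    // CPU re-statement of a strided-index helper like the kernel's get_strided_index:
    // walk the dims from innermost to outermost, peeling off one coordinate at a time.
    static size_t get_strided_index(size_t idx, size_t num_dims,
                                    const size_t *dims, const size_t *strides) {
      size_t strided_i = 0;
      for (size_t d = 0; d < num_dims; ++d) {
        size_t dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
      }
      return strided_i;
    }

    int main() {
      // Illustrative 2x3 tensor stored transposed: logical dims {2, 3}, strides {1, 2}.
      const size_t dims[] = {2, 3};
      const size_t strides[] = {1, 2};
      // Each "threadgroup" would own el_to_sum_per_block consecutive logical indices;
      // here we just print the storage offset of every logical element.
      for (size_t idx = 0; idx < 6; ++idx)
        printf("logical %zu -> offset %zu\n", idx, get_strided_index(idx, 2, dims, strides));
      return 0;
    }
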
@@ -59,9 +55,7 @@ METAL_FUNC void argmin(
}
threadgroup_barrier(mem_flags::mem_none);
/*
// reduction in shared memory
*/
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
shared_indices[tid] = shared_indices[tid + s];
@@ -69,8 +63,7 @@ METAL_FUNC void argmin(
} \
threadgroup_barrier(mem_flags::mem_none);
}
if (tid == 0){
if (tid == 0) {
dst[dst_id] = shared_indices[0];
}
}
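
Aside (not part of the diff): the loop above is the standard threadgroup tree reduction that the removed comment block described: at each step, thread tid compares its candidate against the one s slots away, keeps the smaller value and its index, and s halves until thread 0 holds the argmin. A serial C++ emulation of that pattern, with illustrative data standing in for the threadgroup buffers:

    #include <cstdio>
    #include <vector>

    int main() {
      // Stand-ins for threadgroup shared_memory / shared_indices after the load loop.
      std::vector<float> shared_memory = {3.0f, 1.5f, 7.0f, 0.5f, 2.0f, 9.0f, 4.0f, 6.0f};
      std::vector<unsigned> shared_indices = {0, 1, 2, 3, 4, 5, 6, 7};
      unsigned block_dim = (unsigned)shared_memory.size();

      // Tree reduction: each pass emulates the threads with tid < s running in lockstep,
      // followed by the kernel's threadgroup barrier.
      for (unsigned s = block_dim / 2; s > 0; s >>= 1) {
        for (unsigned tid = 0; tid < s; ++tid) {
          if (shared_memory[tid + s] < shared_memory[tid]) {
            shared_indices[tid] = shared_indices[tid + s];
            shared_memory[tid] = shared_memory[tid + s];
          }
        }
      }

      // "Thread 0" writes the result, as in the kernel's final if (tid == 0) block.
      printf("argmin index = %u, value = %f\n", shared_indices[0], shared_memory[0]);
      return 0;
    }
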
@@ -111,18 +104,14 @@ METAL_FUNC void argmax(
threadgroup T * shared_memory,
threadgroup uint * shared_indices
) {
/*
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
*/
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
bool notset = true;
while (idx < stop_idx) {
/*
// TODO: Fast version for the contiguous case.
*/
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (notset || shared_memory[tid] < src[strided_i]) {
shared_memory[tid] = src[strided_i];
@@ -134,9 +123,7 @@ METAL_FUNC void argmax(
threadgroup_barrier(mem_flags::mem_none);
/*
// reduction in shared memory
*/
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
shared_indices[tid] = shared_indices[tid + s];
@@ -145,9 +132,7 @@ METAL_FUNC void argmax(
threadgroup_barrier(mem_flags::mem_none);
}
/*
// Thread 0 writes the result of the reduction
*/
if (tid == 0) {
dst[dst_id] = shared_indices[0];
}
@@ -188,17 +173,13 @@ METAL_FUNC void reduce(
threadgroup T * shared_memory,
T (*fn)(T, T)
) {
/*
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
*/
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
/*
// TODO: Fast version for the contiguous case.
*/
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
T x = shared_memory[tid];
T y = src[strided_i];
@@ -208,9 +189,7 @@ METAL_FUNC void reduce(
threadgroup_barrier(mem_flags::mem_none);
/*
// reduction in shared memory
*/
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
T x = shared_memory[tid];
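
Aside (not part of the diff): unlike argmin/argmax, the generic reduce above takes a binary combiner fn(T, T), so sum, min, max, etc. can share one kernel body. A minimal serial C++ sketch of that shape; the function names and the flat signature here are illustrative, not the kernel's actual interface:

    #include <cstdio>

    // Example combiners that could be plugged into a reduce of this shape.
    static float add(float x, float y) { return x + y; }
    static float max_of(float x, float y) { return x > y ? x : y; }

    // Serial stand-in for the kernel: fold src[0..n) with fn, starting from init.
    static float reduce(const float *src, unsigned n, float init, float (*fn)(float, float)) {
      float acc = init;
      for (unsigned i = 0; i < n; ++i)
        acc = fn(acc, src[i]);
      return acc;
    }

    int main() {
      const float src[] = {1.0f, -2.5f, 4.0f, 0.5f};
      printf("sum = %f\n", reduce(src, 4, 0.0f, add));
      printf("max = %f\n", reduce(src, 4, src[0], max_of));
      return 0;
    }
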
@@ -277,7 +256,6 @@ METAL_FUNC void softmax(
}
/* wait for shared_memory[0] to be filled */
\
threadgroup_barrier(mem_flags::mem_threadgroup);
float _max = shared_memory[0];
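
Aside (not part of the diff): in the softmax snippet, the threadgroup first reduces the row maximum into shared_memory[0]; the mem_threadgroup barrier then makes that value visible to every thread before each one reads _max and exponentiates its shifted inputs. A serial C++ sketch of the numerically stable softmax this implements (the row contents are illustrative, and the kernel's later sum/normalize passes are assumed to follow the usual pattern):

    #include <cmath>
    #include <cstdio>

    int main() {
      float row[] = {2.0f, 0.5f, -1.0f, 3.0f};
      const unsigned n = 4;

      // Pass 1: row maximum (the kernel reduces this into shared_memory[0],
      // then the threadgroup barrier publishes it to every thread).
      float _max = row[0];
      for (unsigned i = 1; i < n; ++i)
        if (row[i] > _max) _max = row[i];

      // Pass 2: exponentiate the shifted values and accumulate their sum.
      float sum = 0.0f;
      for (unsigned i = 0; i < n; ++i) {
        row[i] = expf(row[i] - _max);
        sum += row[i];
      }

      // Pass 3: normalize.
      for (unsigned i = 0; i < n; ++i)
        printf("softmax[%u] = %f\n", i, row[i] / sum);
      return 0;
    }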