mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Minor cleanups in reduce.metal. (#2004)
This commit is contained in:
@ -37,17 +37,13 @@ METAL_FUNC void argmin(
|
||||
threadgroup uint *shared_indices
|
||||
) {
|
||||
bool notset = true;
|
||||
/*
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
*/
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||
size_t idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
/*
|
||||
// TODO: Fast version for the contiguous case.
|
||||
*/
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||
if (notset || src[strided_i] < shared_memory[tid]) {
|
||||
shared_memory[tid] = src[strided_i];
|
||||
@ -59,9 +55,7 @@ METAL_FUNC void argmin(
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
/*
|
||||
// reduction in shared memory
|
||||
*/
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
|
||||
shared_indices[tid] = shared_indices[tid + s];
|
||||
@ -69,8 +63,7 @@ METAL_FUNC void argmin(
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
if (tid == 0){
|
||||
if (tid == 0) {
|
||||
dst[dst_id] = shared_indices[0];
|
||||
}
|
||||
}
|
||||
@ -111,18 +104,14 @@ METAL_FUNC void argmax(
|
||||
threadgroup T * shared_memory,
|
||||
threadgroup uint * shared_indices
|
||||
) {
|
||||
/*
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
*/
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||
size_t idx = start_idx + tid;
|
||||
bool notset = true;
|
||||
while (idx < stop_idx) {
|
||||
/*
|
||||
// TODO: Fast version for the contiguous case.
|
||||
*/
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||
if (notset || shared_memory[tid] < src[strided_i]) {
|
||||
shared_memory[tid] = src[strided_i];
|
||||
@ -134,9 +123,7 @@ METAL_FUNC void argmax(
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
/*
|
||||
// reduction in shared memory
|
||||
*/
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
|
||||
shared_indices[tid] = shared_indices[tid + s];
|
||||
@ -145,9 +132,7 @@ METAL_FUNC void argmax(
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
/*
|
||||
// Thread 0 writes the result of the reduction
|
||||
*/
|
||||
if (tid == 0) {
|
||||
dst[dst_id] = shared_indices[0];
|
||||
}
|
||||
@ -188,17 +173,13 @@ METAL_FUNC void reduce(
|
||||
threadgroup T * shared_memory,
|
||||
T (*fn)(T, T)
|
||||
) {
|
||||
/*
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
*/
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||
size_t idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
/*
|
||||
// TODO: Fast version for the contiguous case.
|
||||
*/
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||
T x = shared_memory[tid];
|
||||
T y = src[strided_i];
|
||||
@ -208,9 +189,7 @@ METAL_FUNC void reduce(
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
/*
|
||||
// reduction in shared memory
|
||||
*/
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
T x = shared_memory[tid];
|
||||
@ -277,7 +256,6 @@ METAL_FUNC void softmax(
|
||||
}
|
||||
|
||||
/* wait for shared_memory[0] to be filled */
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float _max = shared_memory[0];
|
||||
|
Reference in New Issue
Block a user