mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Minor cleanups in reduce.metal. (#2004)
This commit is contained in:
@ -37,17 +37,13 @@ METAL_FUNC void argmin(
|
|||||||
threadgroup uint *shared_indices
|
threadgroup uint *shared_indices
|
||||||
) {
|
) {
|
||||||
bool notset = true;
|
bool notset = true;
|
||||||
/*
|
|
||||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||||
// to (dst_id + 1) * el_to_sum_per_block.
|
// to (dst_id + 1) * el_to_sum_per_block.
|
||||||
*/
|
|
||||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||||
size_t idx = start_idx + tid;
|
size_t idx = start_idx + tid;
|
||||||
while (idx < stop_idx) {
|
while (idx < stop_idx) {
|
||||||
/*
|
|
||||||
// TODO: Fast version for the contiguous case.
|
// TODO: Fast version for the contiguous case.
|
||||||
*/
|
|
||||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||||
if (notset || src[strided_i] < shared_memory[tid]) {
|
if (notset || src[strided_i] < shared_memory[tid]) {
|
||||||
shared_memory[tid] = src[strided_i];
|
shared_memory[tid] = src[strided_i];
|
||||||
@ -59,9 +55,7 @@ METAL_FUNC void argmin(
|
|||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
/*
|
|
||||||
// reduction in shared memory
|
// reduction in shared memory
|
||||||
*/
|
|
||||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
|
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
|
||||||
shared_indices[tid] = shared_indices[tid + s];
|
shared_indices[tid] = shared_indices[tid + s];
|
||||||
@ -69,8 +63,7 @@ METAL_FUNC void argmin(
|
|||||||
} \
|
} \
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
}
|
}
|
||||||
|
if (tid == 0) {
|
||||||
if (tid == 0){
|
|
||||||
dst[dst_id] = shared_indices[0];
|
dst[dst_id] = shared_indices[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -111,18 +104,14 @@ METAL_FUNC void argmax(
|
|||||||
threadgroup T * shared_memory,
|
threadgroup T * shared_memory,
|
||||||
threadgroup uint * shared_indices
|
threadgroup uint * shared_indices
|
||||||
) {
|
) {
|
||||||
/*
|
|
||||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||||
// to (dst_id + 1) * el_to_sum_per_block.
|
// to (dst_id + 1) * el_to_sum_per_block.
|
||||||
*/
|
|
||||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||||
size_t idx = start_idx + tid;
|
size_t idx = start_idx + tid;
|
||||||
bool notset = true;
|
bool notset = true;
|
||||||
while (idx < stop_idx) {
|
while (idx < stop_idx) {
|
||||||
/*
|
|
||||||
// TODO: Fast version for the contiguous case.
|
// TODO: Fast version for the contiguous case.
|
||||||
*/
|
|
||||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||||
if (notset || shared_memory[tid] < src[strided_i]) {
|
if (notset || shared_memory[tid] < src[strided_i]) {
|
||||||
shared_memory[tid] = src[strided_i];
|
shared_memory[tid] = src[strided_i];
|
||||||
@ -134,9 +123,7 @@ METAL_FUNC void argmax(
|
|||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
/*
|
|
||||||
// reduction in shared memory
|
// reduction in shared memory
|
||||||
*/
|
|
||||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
|
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
|
||||||
shared_indices[tid] = shared_indices[tid + s];
|
shared_indices[tid] = shared_indices[tid + s];
|
||||||
@ -145,9 +132,7 @@ METAL_FUNC void argmax(
|
|||||||
threadgroup_barrier(mem_flags::mem_none);
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
// Thread 0 writes the result of the reduction
|
// Thread 0 writes the result of the reduction
|
||||||
*/
|
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
dst[dst_id] = shared_indices[0];
|
dst[dst_id] = shared_indices[0];
|
||||||
}
|
}
|
||||||
@ -188,17 +173,13 @@ METAL_FUNC void reduce(
|
|||||||
threadgroup T * shared_memory,
|
threadgroup T * shared_memory,
|
||||||
T (*fn)(T, T)
|
T (*fn)(T, T)
|
||||||
) {
|
) {
|
||||||
/*
|
|
||||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||||
// to (dst_id + 1) * el_to_sum_per_block.
|
// to (dst_id + 1) * el_to_sum_per_block.
|
||||||
*/
|
|
||||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||||
size_t idx = start_idx + tid;
|
size_t idx = start_idx + tid;
|
||||||
while (idx < stop_idx) {
|
while (idx < stop_idx) {
|
||||||
/*
|
|
||||||
// TODO: Fast version for the contiguous case.
|
// TODO: Fast version for the contiguous case.
|
||||||
*/
|
|
||||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||||
T x = shared_memory[tid];
|
T x = shared_memory[tid];
|
||||||
T y = src[strided_i];
|
T y = src[strided_i];
|
||||||
@ -208,9 +189,7 @@ METAL_FUNC void reduce(
|
|||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
/*
|
|
||||||
// reduction in shared memory
|
// reduction in shared memory
|
||||||
*/
|
|
||||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||||
if (tid < s) {
|
if (tid < s) {
|
||||||
T x = shared_memory[tid];
|
T x = shared_memory[tid];
|
||||||
@ -277,7 +256,6 @@ METAL_FUNC void softmax(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* wait for shared_memory[0] to be filled */
|
/* wait for shared_memory[0] to be filled */
|
||||||
\
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
float _max = shared_memory[0];
|
float _max = shared_memory[0];
|
||||||
|
Reference in New Issue
Block a user