mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Fixing the kernels + launches to make them faster.
Cool work by @ivarflakstad Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:

committed by
Nicolas Patry

parent
f82bf2d915
commit
f710fab02e
@ -23,17 +23,14 @@ kernel void FN_NAME( \
|
||||
device const TYPENAME *left, \
|
||||
device const TYPENAME *right, \
|
||||
device TYPENAME *output, \
|
||||
uint threadgroup_size [[threads_per_threadgroup]], \
|
||||
uint thread_index [[thread_index_in_threadgroup]] \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
|
||||
const size_t start = thread_index * length; \
|
||||
const size_t stop = min(start + length, dim); \
|
||||
for (size_t i = start; i < stop; i++){ \
|
||||
TYPENAME x = left[i]; \
|
||||
TYPENAME y = right[i]; \
|
||||
output[i] = OUT_TYPENAME(FN); \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
TYPENAME x = left[thread_position_in_grid]; \
|
||||
TYPENAME y = right[thread_position_in_grid]; \
|
||||
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
|
||||
}\
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
@ -44,17 +41,14 @@ kernel void FN_NAME_STRIDED( \
|
||||
device const TYPENAME *left, \
|
||||
device const TYPENAME *right, \
|
||||
device TYPENAME *output, \
|
||||
uint threadgroup_size [[threads_per_threadgroup]], \
|
||||
uint thread_index [[thread_index_in_threadgroup]] \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
|
||||
const size_t start = thread_index * length; \
|
||||
const size_t stop = min(start + length, dim); \
|
||||
for (size_t i = start; i < stop; i++){ \
|
||||
TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \
|
||||
TYPENAME y = left[get_strided_index(i, num_dims, dims, right_strides)]; \
|
||||
output[i] = OUT_TYPENAME(FN); \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
|
||||
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
|
||||
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
|
||||
}
|
||||
|
||||
#define BINARY_OP(FN, NAME) \
|
||||
|
Reference in New Issue
Block a user