mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Float -> half / bfloat conversion in unary
This commit is contained in:
@ -44,7 +44,7 @@ kernel void FN_NAME( \
|
||||
uint thread_index [[thread_index_in_threadgroup]] \
|
||||
) { \
|
||||
const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
|
||||
output[i] = FN(input[i]); \
|
||||
output[i] = TYPENAME(FN(input[i])); \
|
||||
}\
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
@ -61,7 +61,7 @@ kernel void FN_NAME_STRIDED( \
|
||||
const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
|
||||
const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \
|
||||
for (size_t i = start; i < stop; i++) { \
|
||||
output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \
|
||||
output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
|
||||
output[i] = 1; \
|
||||
} \
|
||||
}
|
||||
|
Reference in New Issue
Block a user