Float -> half / bfloat conversion in unary

This commit is contained in:
Ivar Flakstad
2023-11-06 17:09:39 +01:00
parent 677495f9b8
commit 4d87305c48

View File

@ -44,7 +44,7 @@ kernel void FN_NAME( \
uint thread_index [[thread_index_in_threadgroup]] \ uint thread_index [[thread_index_in_threadgroup]] \
) { \ ) { \
const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
output[i] = FN(input[i]); \ output[i] = TYPENAME(FN(input[i])); \
}\ }\
kernel void FN_NAME_STRIDED( \ kernel void FN_NAME_STRIDED( \
constant size_t &dim, \ constant size_t &dim, \
@ -61,7 +61,7 @@ kernel void FN_NAME_STRIDED( \
const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \
const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \ const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \
for (size_t i = start; i < stop; i++) { \ for (size_t i = start; i < stop; i++) { \
output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \ output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
output[i] = 1; \ output[i] = 1; \
} \ } \
} }