diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index b8056909..715dcced 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -44,7 +44,7 @@ kernel void FN_NAME( \ uint thread_index [[thread_index_in_threadgroup]] \ ) { \ const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ - output[i] = FN(input[i]); \ + output[i] = TYPENAME(FN(input[i])); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -61,7 +61,7 @@ kernel void FN_NAME_STRIDED( \ const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \ for (size_t i = start; i < stop; i++) { \ - output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \ + output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \ output[i] = 1; \ } \ }