From 4d87305c48638e6cefd9669101d175acf1baf43a Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 6 Nov 2023 17:09:39 +0100 Subject: [PATCH] Float -> half / bfloat conversion in unary --- candle-metal-kernels/src/unary.metal | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index b8056909..715dcced 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -44,7 +44,7 @@ kernel void FN_NAME( \ uint thread_index [[thread_index_in_threadgroup]] \ ) { \ const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ - output[i] = FN(input[i]); \ + output[i] = TYPENAME(FN(input[i])); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -61,7 +61,7 @@ kernel void FN_NAME_STRIDED( \ const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \ for (size_t i = start; i < stop; i++) { \ - output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \ + output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \ output[i] = 1; \ } \ }