From 1367e0278b8ac1ce5a3a27fab9dae390b26879cf Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 7 Nov 2023 10:26:59 +0100 Subject: [PATCH] pesky bfloat type --- candle-metal-kernels/src/unary.metal | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 7349ce97..81171fb2 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -37,7 +37,7 @@ kernel void FN_NAME( \ const size_t start = thread_index * length; \ const size_t stop = min(start + length, dim); \ for (size_t i = start; i < stop; i++){ \ - output[i] = FN(input[i]); \ + output[i] = TYPENAME(FN(input[i])); \ } \ }\ kernel void FN_NAME_STRIDED( \ @@ -55,7 +55,7 @@ kernel void FN_NAME_STRIDED( \ const size_t start = thread_index * length; \ const size_t stop = min(start + length, dim); \ for (size_t i = start; i < stop; i++){ \ - output[i] = FN(input[get_strided_index(i, num_dims, dims, strides, offset)]); \ + output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides, offset)])); \ } \ }