From d3bdd788cfdcf49b6ea539b77647b82a0b979db0 Mon Sep 17 00:00:00 2001 From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Wed, 10 Jan 2024 18:50:30 +0100 Subject: [PATCH] Use __HAVE_BFLOAT__ to check for bfloat support instead of metal version check (#1540) --- candle-metal-kernels/src/affine.metal | 2 +- candle-metal-kernels/src/binary.metal | 2 +- candle-metal-kernels/src/cast.metal | 2 +- candle-metal-kernels/src/indexing.metal | 2 +- candle-metal-kernels/src/reduce.metal | 2 +- candle-metal-kernels/src/unary.metal | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 4166d811..3d8e7f0d 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -117,7 +117,7 @@ ELU(elu_f32, float) ELU(elu_f16, half) -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) AFFINE(affine_bf16, bfloat); POWF(powf_bf16, bfloat); ELU(elu_bf16, bfloat); diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index cdc8fef8..eb560f16 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y) INT64_BINARY_OP_OUT(gt, x > y) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) BFLOAT_BINARY_OP(x + y, add) BFLOAT_BINARY_OP(x - y, sub) BFLOAT_BINARY_OP(x * y, mul) diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index e9ab17b1..5aacac4a 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -58,7 +58,7 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) #endif diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 63357428..32f3f410 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -173,7 +173,7 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float) SCATTER_ADD_OP(sa_u32_f16, uint, half) -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 83a56f0a..93dac662 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -295,7 +295,7 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) REDUCE(x + y, fast_sum_bf16, bfloat, 0) REDUCE(x * y, fast_mul_bf16, bfloat, 1) REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF) diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index f95f6ba9..dcf803d8 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -127,7 +127,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided) UNARY(id, int64_t, copy_i64, copy_i64_strided) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(sin) BFLOAT_UNARY_OP(sqr)