diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 4166d811..3d8e7f0d 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -117,7 +117,7 @@ ELU(elu_f32, float) ELU(elu_f16, half) -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) AFFINE(affine_bf16, bfloat); POWF(powf_bf16, bfloat); ELU(elu_bf16, bfloat); diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index cdc8fef8..eb560f16 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y) INT64_BINARY_OP_OUT(gt, x > y) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) BFLOAT_BINARY_OP(x + y, add) BFLOAT_BINARY_OP(x - y, sub) BFLOAT_BINARY_OP(x * y, mul) diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index e9ab17b1..5aacac4a 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -58,7 +58,7 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) #endif diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 63357428..32f3f410 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -173,7 +173,7 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float) SCATTER_ADD_OP(sa_u32_f16, uint, half) -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 83a56f0a..93dac662 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -295,7 +295,7 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) REDUCE(x + y, fast_sum_bf16, bfloat, 0) REDUCE(x * y, fast_mul_bf16, bfloat, 1) REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF) diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index f95f6ba9..dcf803d8 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -127,7 +127,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided) UNARY(id, int64_t, copy_i64, copy_i64_strided) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(sin) BFLOAT_UNARY_OP(sqr)