mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Use __HAVE_BFLOAT__ to check for bfloat support instead of metal version check (#1540)
This commit is contained in:
@ -117,7 +117,7 @@ ELU(elu_f32, float)
|
||||
ELU(elu_f16, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
AFFINE(affine_bf16, bfloat);
|
||||
POWF(powf_bf16, bfloat);
|
||||
ELU(elu_bf16, bfloat);
|
||||
|
@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y)
|
||||
INT64_BINARY_OP_OUT(gt, x > y)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
BFLOAT_BINARY_OP(x + y, add)
|
||||
BFLOAT_BINARY_OP(x - y, sub)
|
||||
BFLOAT_BINARY_OP(x * y, mul)
|
||||
|
@ -58,7 +58,7 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
|
||||
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
|
||||
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
|
||||
#endif
|
||||
|
@ -173,7 +173,7 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float)
|
||||
SCATTER_ADD_OP(sa_u32_f16, uint, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
|
||||
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
|
||||
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
||||
|
@ -295,7 +295,7 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
||||
|
@ -127,7 +127,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||
UNARY(id, int64_t, copy_i64, copy_i64_strided)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
BFLOAT_UNARY_OP(cos)
|
||||
BFLOAT_UNARY_OP(sin)
|
||||
BFLOAT_UNARY_OP(sqr)
|
||||
|
Reference in New Issue
Block a user