update dtypes checks for several metal operations (#2010)

This commit is contained in:
Thomas Santerre
2024-04-04 12:39:06 -04:00
committed by GitHub
parent f76bb7794a
commit 5aebe53dd2
3 changed files with 69 additions and 37 deletions

View File

@ -60,21 +60,24 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
#define INT64_BINARY_OP(NAME, FN) \
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
#define BINARY_OP_OUT(NAME, FN) \
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
#define INT64_BINARY_OP(NAME, FN) \
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
#define INT64_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
#define BFLOAT_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, bfloat, uint8_t, NAME##_bf16, NAME##_bf16_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
@ -112,4 +115,11 @@ BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
BFLOAT_BINARY_OP(MIN(x, y), min)
BFLOAT_BINARY_OP(MAX(x, y), max)
BFLOAT_BINARY_OP_OUT(eq, x == y)
BFLOAT_BINARY_OP_OUT(ne, x != y)
BFLOAT_BINARY_OP_OUT(le, x <= y)
BFLOAT_BINARY_OP_OUT(lt, x < y)
BFLOAT_BINARY_OP_OUT(ge, x >= y)
BFLOAT_BINARY_OP_OUT(gt, x > y)
#endif