mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Metal: i64 basic support (#1495)
* Adds basic metal i64 support * metal copy i64
This commit is contained in:
@ -58,6 +58,9 @@ kernel void FN_NAME_STRIDED( \
|
||||
BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
|
||||
BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
|
||||
|
||||
#define INT64_BINARY_OP(NAME, FN) \
|
||||
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
|
||||
|
||||
#define BFLOAT_BINARY_OP(FN, NAME) \
|
||||
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
||||
|
||||
@ -65,6 +68,8 @@ BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
||||
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
|
||||
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
|
||||
|
||||
#define INT64_BINARY_OP_OUT(NAME, FN) \
|
||||
BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
|
||||
|
||||
BINARY_OP(x + y, add)
|
||||
BINARY_OP(x - y, sub)
|
||||
@ -80,6 +85,22 @@ BINARY_OP_OUT(lt, x < y)
|
||||
BINARY_OP_OUT(ge, x >= y)
|
||||
BINARY_OP_OUT(gt, x > y)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
INT64_BINARY_OP(add, x + y)
|
||||
INT64_BINARY_OP(sub, x - y)
|
||||
INT64_BINARY_OP(mul, x * y)
|
||||
INT64_BINARY_OP(div, x / y)
|
||||
INT64_BINARY_OP(min, MIN(x, y))
|
||||
INT64_BINARY_OP(max, MAX(x, y))
|
||||
|
||||
INT64_BINARY_OP_OUT(eq, x == y)
|
||||
INT64_BINARY_OP_OUT(ne, x != y)
|
||||
INT64_BINARY_OP_OUT(le, x <= y)
|
||||
INT64_BINARY_OP_OUT(lt, x < y)
|
||||
INT64_BINARY_OP_OUT(ge, x >= y)
|
||||
INT64_BINARY_OP_OUT(gt, x > y)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
BFLOAT_BINARY_OP(x + y, add)
|
||||
BFLOAT_BINARY_OP(x - y, sub)
|
||||
|
Reference in New Issue
Block a user