Metal: i64 basic support (#1495)

* Adds basic metal i64 support

* metal copy i64
This commit is contained in:
Gonzalo
2023-12-29 15:42:50 -03:00
committed by GitHub
parent f4a2787217
commit 4373534d59
7 changed files with 83 additions and 1 deletions

View File

@ -58,6 +58,9 @@ kernel void FN_NAME_STRIDED( \
BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
#define INT64_BINARY_OP(NAME, FN) \
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
@ -65,6 +68,8 @@ BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
#define INT64_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
@ -80,6 +85,22 @@ BINARY_OP_OUT(lt, x < y)
BINARY_OP_OUT(ge, x >= y)
BINARY_OP_OUT(gt, x > y)
#if __METAL_VERSION__ >= 220
INT64_BINARY_OP(add, x + y)
INT64_BINARY_OP(sub, x - y)
INT64_BINARY_OP(mul, x * y)
INT64_BINARY_OP(div, x / y)
INT64_BINARY_OP(min, MIN(x, y))
INT64_BINARY_OP(max, MAX(x, y))
INT64_BINARY_OP_OUT(eq, x == y)
INT64_BINARY_OP_OUT(ne, x != y)
INT64_BINARY_OP_OUT(le, x <= y)
INT64_BINARY_OP_OUT(lt, x < y)
INT64_BINARY_OP_OUT(ge, x >= y)
INT64_BINARY_OP_OUT(gt, x > y)
#endif
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)

View File

@ -52,5 +52,11 @@ CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
#if __METAL_VERSION__ >= 220
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif
#if __METAL_VERSION__ >= 310
#endif

View File

@ -130,6 +130,7 @@ macro_rules! ops{
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
}
)+
pub mod copy {
@ -137,6 +138,7 @@ macro_rules! ops{
pub const FLOAT: Kernel = Kernel("copy_f32");
pub const HALF: Kernel = Kernel("copy_f16");
pub const BFLOAT: Kernel = Kernel("copy_bf16");
pub const I64: Kernel = Kernel("copy_i64");
pub const U32: Kernel = Kernel("copy_u32");
pub const U8: Kernel = Kernel("copy_u8");
}
@ -150,6 +152,7 @@ macro_rules! ops{
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
}
)+
pub mod copy {
@ -157,6 +160,7 @@ macro_rules! ops{
pub const FLOAT: Kernel = Kernel("copy_f32_strided");
pub const HALF: Kernel = Kernel("copy_f16_strided");
pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
pub const I64: Kernel = Kernel("copy_i64_strided");
pub const U32: Kernel = Kernel("copy_u32_strided");
pub const U8: Kernel = Kernel("copy_u8_strided");
}

View File

@ -281,6 +281,15 @@ ARGMAX(fast_argmax_u32_strided, uint, 0)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
#endif
#if __METAL_VERSION__ >= 310
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)

View File

@ -57,4 +57,7 @@ WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
// WHERE_OP(int64_t, uint8_t, where_u8_i64)
#if __METAL_VERSION__ >= 220
WHERE_OP(int64_t, uint8_t, where_u8_i64)
#endif

View File

@ -115,6 +115,10 @@ UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
#endif
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)