From 4373534d59d3a6357aef0b3f35a247f695f4700a Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 29 Dec 2023 15:42:50 -0300 Subject: [PATCH] Metal: i64 basic support (#1495) * Adds basic metal i64 support * metal copy i64 --- candle-core/src/metal_backend.rs | 35 ++++++++++++++++++++++++++ candle-metal-kernels/src/binary.metal | 21 ++++++++++++++++ candle-metal-kernels/src/cast.metal | 6 +++++ candle-metal-kernels/src/lib.rs | 4 +++ candle-metal-kernels/src/reduce.metal | 9 +++++++ candle-metal-kernels/src/ternary.metal | 5 +++- candle-metal-kernels/src/unary.metal | 4 +++ 7 files changed, 83 insertions(+), 1 deletion(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 9abbda6e..76577992 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -532,6 +532,11 @@ impl BackendStorage for MetalStorage { (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false), (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true), (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true), + (ReduceOp::Sum, DType::I64) => ("fast_sum_i64_strided", false, false), + (ReduceOp::Min, DType::I64) => ("fast_min_i64_strided", true, false), + (ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false), + (ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64_strided", true, true), + (ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64_strided", true, true), (k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"), }; if check_empty && layout.shape().elem_count() == 0 { @@ -579,10 +584,13 @@ impl BackendStorage for MetalStorage { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::U8) => "cast_u32_u8", + (DType::U32, DType::I64) => "cast_u32_i64", (DType::U8, DType::U32) => "cast_u8_u32", (DType::U8, DType::F32) => "cast_u8_f32", + (DType::U8, DType::I64) => "cast_u8_i64", (DType::F32, DType::F16) => "cast_f32_f16", (DType::F16, DType::F32) => "cast_f16_f32", + (DType::I64, DType::F32) => "cast_i64_f32", (left, right) => { crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented") } @@ -602,10 +610,13 @@ impl BackendStorage for MetalStorage { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32_strided", (DType::U32, DType::U8) => "cast_u32_u8_strided", + (DType::U32, DType::I64) => "cast_u32_i64_strided", (DType::U8, DType::U32) => "cast_u8_u32_strided", (DType::U8, DType::F32) => "cast_u8_f32_strided", + (DType::U8, DType::I64) => "cast_u8_i64_strided", (DType::F32, DType::F16) => "cast_f32_f16_strided", (DType::F16, DType::F32) => "cast_f16_f32_strided", + (DType::I64, DType::F32) => "cast_i64_f32_strided", (left, right) => { crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented") } @@ -767,6 +778,7 @@ impl BackendStorage for MetalStorage { let name = match (self.dtype, t.dtype()) { (DType::U8, DType::F32) => "where_u8_f32", (DType::U8, DType::F16) => "where_u8_f16", + (DType::U8, DType::I64) => "where_u8_i64", (left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"), }; candle_metal_kernels::call_where_cond_strided( @@ -1226,6 +1238,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + DType::I64 => candle_metal_kernels::unary::strided::copy::I64, DType::U32 => candle_metal_kernels::unary::strided::copy::U32, DType::U8 => candle_metal_kernels::unary::strided::copy::U8, dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"), @@ -1300,6 +1313,16 @@ impl MetalStorage { ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8), ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), + ("add", DType::I64) => (contiguous::add::I64, self.dtype), + ("sub", DType::I64) => (contiguous::sub::I64, self.dtype), + ("mul", DType::I64) => (contiguous::mul::I64, self.dtype), + ("div", DType::I64) => (contiguous::div::I64, self.dtype), + ("eq", DType::I64) => (contiguous::eq::I64, DType::U8), + ("ne", DType::I64) => (contiguous::ne::I64, DType::U8), + ("le", DType::I64) => (contiguous::le::I64, DType::U8), + ("lt", DType::I64) => (contiguous::lt::I64, DType::U8), + ("ge", DType::I64) => (contiguous::ge::I64, DType::U8), + ("gt", DType::I64) => (contiguous::gt::I64, DType::U8), (name, dtype) => { crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented") } @@ -1345,6 +1368,18 @@ impl MetalStorage { ("lt", DType::F16) => (strided::lt::HALF, DType::U8), ("ge", DType::F16) => (strided::ge::HALF, DType::U8), ("gt", DType::F16) => (strided::gt::HALF, DType::U8), + ("badd", DType::I64) => (strided::add::I64, self.dtype), + ("bsub", DType::I64) => (strided::sub::I64, self.dtype), + ("bmul", DType::I64) => (strided::mul::I64, self.dtype), + ("bdiv", DType::I64) => (strided::div::I64, self.dtype), + ("bminimum", DType::I64) => (strided::min::I64, self.dtype), + ("bmaximum", DType::I64) => (strided::max::I64, self.dtype), + ("eq", DType::I64) => (strided::eq::I64, DType::U8), + ("ne", DType::I64) => (strided::ne::I64, DType::U8), + ("le", DType::I64) => (strided::le::I64, DType::U8), + ("lt", DType::I64) => (strided::lt::I64, DType::U8), + ("ge", DType::I64) => (strided::ge::I64, DType::U8), + ("gt", DType::I64) => (strided::gt::I64, DType::U8), (name, dtype) => { crate::bail!("Metal strided binary {name} {dtype:?} not implemented") } diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index 8c3b4a8c..30c90ff1 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -58,6 +58,9 @@ kernel void FN_NAME_STRIDED( \ BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); +#define INT64_BINARY_OP(NAME, FN) \ +BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided); + #define BFLOAT_BINARY_OP(FN, NAME) \ BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); @@ -65,6 +68,8 @@ BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); +#define INT64_BINARY_OP_OUT(NAME, FN) \ +BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided); BINARY_OP(x + y, add) BINARY_OP(x - y, sub) @@ -80,6 +85,22 @@ BINARY_OP_OUT(lt, x < y) BINARY_OP_OUT(ge, x >= y) BINARY_OP_OUT(gt, x > y) +#if __METAL_VERSION__ >= 220 +INT64_BINARY_OP(add, x + y) +INT64_BINARY_OP(sub, x - y) +INT64_BINARY_OP(mul, x * y) +INT64_BINARY_OP(div, x / y) +INT64_BINARY_OP(min, MIN(x, y)) +INT64_BINARY_OP(max, MAX(x, y)) + +INT64_BINARY_OP_OUT(eq, x == y) +INT64_BINARY_OP_OUT(ne, x != y) +INT64_BINARY_OP_OUT(le, x <= y) +INT64_BINARY_OP_OUT(lt, x < y) +INT64_BINARY_OP_OUT(ge, x >= y) +INT64_BINARY_OP_OUT(gt, x > y) +#endif + #if __METAL_VERSION__ >= 310 BFLOAT_BINARY_OP(x + y, add) BFLOAT_BINARY_OP(x - y, sub) diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 8481389d..3baefcc2 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -52,5 +52,11 @@ CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) CAST(cast_f16_f32, cast_f16_f32_strided, half, float) CAST(cast_f32_f16, cast_f32_f16_strided, float, half) +#if __METAL_VERSION__ >= 220 +CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t) +CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) +CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) +#endif + #if __METAL_VERSION__ >= 310 #endif diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 94479882..7b0084d9 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -130,6 +130,7 @@ macro_rules! ops{ pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32")); pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); + pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64")); } )+ pub mod copy { @@ -137,6 +138,7 @@ macro_rules! ops{ pub const FLOAT: Kernel = Kernel("copy_f32"); pub const HALF: Kernel = Kernel("copy_f16"); pub const BFLOAT: Kernel = Kernel("copy_bf16"); + pub const I64: Kernel = Kernel("copy_i64"); pub const U32: Kernel = Kernel("copy_u32"); pub const U8: Kernel = Kernel("copy_u8"); } @@ -150,6 +152,7 @@ macro_rules! ops{ pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided")); pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); + pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided")); } )+ pub mod copy { @@ -157,6 +160,7 @@ macro_rules! ops{ pub const FLOAT: Kernel = Kernel("copy_f32_strided"); pub const HALF: Kernel = Kernel("copy_f16_strided"); pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); + pub const I64: Kernel = Kernel("copy_i64_strided"); pub const U32: Kernel = Kernel("copy_u32_strided"); pub const U8: Kernel = Kernel("copy_u8_strided"); } diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 2d584917..38252967 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -281,6 +281,15 @@ ARGMAX(fast_argmax_u32_strided, uint, 0) SOFTMAX(softmax_f32, float) SOFTMAX(softmax_f16, half) + +#if __METAL_VERSION__ >= 220 +REDUCE(x + y, fast_sum_i64_strided, int64_t, 0) +REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN) +ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) +ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) +#endif + #if __METAL_VERSION__ >= 310 REDUCE(x + y, fast_sum_bf16, bfloat, 0) REDUCE(x * y, fast_mul_bf16, bfloat, 1) diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 1f9cb38a..dfa0dd12 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -57,4 +57,7 @@ WHERE_OP(float, uint8_t, where_u8_f32) // WHERE_OP(double, uint8_t, where_u8_f64) // WHERE_OP(uint8_t, uint8_t, where_u8_u8) // WHERE_OP(uint32_t, uint8_t, where_u8_u32) -// WHERE_OP(int64_t, uint8_t, where_u8_i64) + +#if __METAL_VERSION__ >= 220 +WHERE_OP(int64_t, uint8_t, where_u8_i64) +#endif diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 826b9045..15d1e400 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -115,6 +115,10 @@ UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) UNARY(id, uint32_t, copy_u32, copy_u32_strided) +#if __METAL_VERSION__ >= 220 +UNARY(id, int64_t, copy_i64, copy_i64_strided) +#endif + #if __METAL_VERSION__ >= 310 BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(sin)