diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 76577992..7a22595e 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -537,6 +537,11 @@ impl BackendStorage for MetalStorage { (ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false), (ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64_strided", true, true), (ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64_strided", true, true), + (ReduceOp::Sum, DType::U8) => ("fast_sum_u8_strided", false, false), + (ReduceOp::Min, DType::U8) => ("fast_min_u8_strided", true, false), + (ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false), + (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true), + (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true), (k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"), }; if check_empty && layout.shape().elem_count() == 0 { @@ -779,6 +784,8 @@ impl BackendStorage for MetalStorage { (DType::U8, DType::F32) => "where_u8_f32", (DType::U8, DType::F16) => "where_u8_f16", (DType::U8, DType::I64) => "where_u8_i64", + (DType::U8, DType::U32) => "where_u8_u32", + (DType::U8, DType::U8) => "where_u8_u8", (left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"), }; candle_metal_kernels::call_where_cond_strided( @@ -1323,6 +1330,26 @@ impl MetalStorage { ("lt", DType::I64) => (contiguous::lt::I64, DType::U8), ("ge", DType::I64) => (contiguous::ge::I64, DType::U8), ("gt", DType::I64) => (contiguous::gt::I64, DType::U8), + ("add", DType::U32) => (contiguous::add::U32, self.dtype), + ("sub", DType::U32) => (contiguous::sub::U32, self.dtype), + ("mul", DType::U32) => (contiguous::mul::U32, self.dtype), + ("div", DType::U32) => (contiguous::div::U32, self.dtype), + ("eq", DType::U32) => (contiguous::eq::U32, DType::U8), + ("ne", DType::U32) => (contiguous::ne::U32, DType::U8), + ("le", DType::U32) => (contiguous::le::U32, DType::U8), + ("lt", DType::U32) => (contiguous::lt::U32, DType::U8), + ("ge", DType::U32) => (contiguous::ge::U32, DType::U8), + ("gt", DType::U32) => (contiguous::gt::U32, DType::U8), + ("add", DType::U8) => (contiguous::add::U8, self.dtype), + ("sub", DType::U8) => (contiguous::sub::U8, self.dtype), + ("mul", DType::U8) => (contiguous::mul::U8, self.dtype), + ("div", DType::U8) => (contiguous::div::U8, self.dtype), + ("eq", DType::U8) => (contiguous::eq::U8, DType::U8), + ("ne", DType::U8) => (contiguous::ne::U8, DType::U8), + ("le", DType::U8) => (contiguous::le::U8, DType::U8), + ("lt", DType::U8) => (contiguous::lt::U8, DType::U8), + ("ge", DType::U8) => (contiguous::ge::U8, DType::U8), + ("gt", DType::U8) => (contiguous::gt::U8, DType::U8), (name, dtype) => { crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented") } @@ -1380,6 +1407,30 @@ impl MetalStorage { ("lt", DType::I64) => (strided::lt::I64, DType::U8), ("ge", DType::I64) => (strided::ge::I64, DType::U8), ("gt", DType::I64) => (strided::gt::I64, DType::U8), + ("badd", DType::U32) => (strided::add::U32, self.dtype), + ("bsub", DType::U32) => (strided::sub::U32, self.dtype), + ("bmul", DType::U32) => (strided::mul::U32, self.dtype), + ("bdiv", DType::U32) => (strided::div::U32, self.dtype), + ("bminimum", DType::U32) => (strided::min::U32, self.dtype), + ("bmaximum", DType::U32) => (strided::max::U32, self.dtype), + ("eq", DType::U32) => (strided::eq::U32, DType::U8), + ("ne", DType::U32) => (strided::ne::U32, DType::U8), + ("le", DType::U32) => (strided::le::U32, DType::U8), + ("lt", DType::U32) => (strided::lt::U32, DType::U8), + ("ge", DType::U32) => (strided::ge::U32, DType::U8), + ("gt", DType::U32) => (strided::gt::U32, DType::U8), + ("badd", DType::U8) => (strided::add::U8, self.dtype), + ("bsub", DType::U8) => (strided::sub::U8, self.dtype), + ("bmul", DType::U8) => (strided::mul::U8, self.dtype), + ("bdiv", DType::U8) => (strided::div::U8, self.dtype), + ("bminimum", DType::U8) => (strided::min::U8, self.dtype), + ("bmaximum", DType::U8) => (strided::max::U8, self.dtype), + ("eq", DType::U8) => (strided::eq::U8, DType::U8), + ("ne", DType::U8) => (strided::ne::U8, DType::U8), + ("le", DType::U8) => (strided::le::U8, DType::U8), + ("lt", DType::U8) => (strided::lt::U8, DType::U8), + ("ge", DType::U8) => (strided::ge::U8, DType::U8), + ("gt", DType::U8) => (strided::gt::U8, DType::U8), (name, dtype) => { crate::bail!("Metal strided binary {name} {dtype:?} not implemented") } diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index 30c90ff1..cdc8fef8 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -56,7 +56,9 @@ kernel void FN_NAME_STRIDED( \ #define BINARY_OP(FN, NAME) \ BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ -BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); +BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \ +BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \ +BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); #define INT64_BINARY_OP(NAME, FN) \ BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided); @@ -66,7 +68,9 @@ BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); #define BINARY_OP_OUT(NAME, FN) \ BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ -BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); +BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \ +BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \ +BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); #define INT64_BINARY_OP_OUT(NAME, FN) \ BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided); diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 7b0084d9..d080ef52 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -131,6 +131,8 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64")); + pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32")); + pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8")); } )+ pub mod copy { @@ -153,6 +155,8 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided")); + pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided")); + pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided")); } )+ pub mod copy { diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 38252967..83a56f0a 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -263,21 +263,26 @@ kernel void NAME( REDUCE(x + y, fast_sum_f32_strided, float, 0) REDUCE(x + y, fast_sum_u32_strided, uint, 0) REDUCE(x + y, fast_sum_f16_strided, half, 0) +REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0) REDUCE(x * y, fast_mul_f32_strided, float, 1) REDUCE(x * y, fast_mul_u32_strided, uint, 1) REDUCE(x * y, fast_mul_f16_strided, half, 1) REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF) REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0) REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH) +REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0) REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF) REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF) REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH) +REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF) ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF) ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH) ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF) +ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF) ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF) ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH) ARGMAX(fast_argmax_u32_strided, uint, 0) +ARGMAX(fast_argmax_u8_strided, uint8_t, 0) SOFTMAX(softmax_f32, float) SOFTMAX(softmax_f16, half) diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index dfa0dd12..40b4bcf4 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -55,8 +55,8 @@ kernel void FN_NAME( \ WHERE_OP(float, uint8_t, where_u8_f32) // WHERE_OP(double, uint8_t, where_u8_f64) -// WHERE_OP(uint8_t, uint8_t, where_u8_u8) -// WHERE_OP(uint32_t, uint8_t, where_u8_u32) +WHERE_OP(uint8_t, uint8_t, where_u8_u8) +WHERE_OP(uint32_t, uint8_t, where_u8_u32) #if __METAL_VERSION__ >= 220 WHERE_OP(int64_t, uint8_t, where_u8_i64)