mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
@ -537,6 +537,11 @@ impl BackendStorage for MetalStorage {
|
|||||||
(ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false),
|
(ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false),
|
||||||
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64_strided", true, true),
|
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64_strided", true, true),
|
||||||
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64_strided", true, true),
|
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64_strided", true, true),
|
||||||
|
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8_strided", false, false),
|
||||||
|
(ReduceOp::Min, DType::U8) => ("fast_min_u8_strided", true, false),
|
||||||
|
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
|
||||||
|
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
|
||||||
|
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
|
||||||
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
|
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
|
||||||
};
|
};
|
||||||
if check_empty && layout.shape().elem_count() == 0 {
|
if check_empty && layout.shape().elem_count() == 0 {
|
||||||
@ -779,6 +784,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
(DType::U8, DType::F32) => "where_u8_f32",
|
(DType::U8, DType::F32) => "where_u8_f32",
|
||||||
(DType::U8, DType::F16) => "where_u8_f16",
|
(DType::U8, DType::F16) => "where_u8_f16",
|
||||||
(DType::U8, DType::I64) => "where_u8_i64",
|
(DType::U8, DType::I64) => "where_u8_i64",
|
||||||
|
(DType::U8, DType::U32) => "where_u8_u32",
|
||||||
|
(DType::U8, DType::U8) => "where_u8_u8",
|
||||||
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
|
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_where_cond_strided(
|
candle_metal_kernels::call_where_cond_strided(
|
||||||
@ -1323,6 +1330,26 @@ impl MetalStorage {
|
|||||||
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
|
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
|
||||||
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
|
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
|
||||||
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
|
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
|
||||||
|
("add", DType::U32) => (contiguous::add::U32, self.dtype),
|
||||||
|
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
|
||||||
|
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
|
||||||
|
("div", DType::U32) => (contiguous::div::U32, self.dtype),
|
||||||
|
("eq", DType::U32) => (contiguous::eq::U32, DType::U8),
|
||||||
|
("ne", DType::U32) => (contiguous::ne::U32, DType::U8),
|
||||||
|
("le", DType::U32) => (contiguous::le::U32, DType::U8),
|
||||||
|
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
|
||||||
|
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
|
||||||
|
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
|
||||||
|
("add", DType::U8) => (contiguous::add::U8, self.dtype),
|
||||||
|
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
|
||||||
|
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
|
||||||
|
("div", DType::U8) => (contiguous::div::U8, self.dtype),
|
||||||
|
("eq", DType::U8) => (contiguous::eq::U8, DType::U8),
|
||||||
|
("ne", DType::U8) => (contiguous::ne::U8, DType::U8),
|
||||||
|
("le", DType::U8) => (contiguous::le::U8, DType::U8),
|
||||||
|
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
|
||||||
|
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
|
||||||
|
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
|
||||||
(name, dtype) => {
|
(name, dtype) => {
|
||||||
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
|
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
|
||||||
}
|
}
|
||||||
@ -1380,6 +1407,30 @@ impl MetalStorage {
|
|||||||
("lt", DType::I64) => (strided::lt::I64, DType::U8),
|
("lt", DType::I64) => (strided::lt::I64, DType::U8),
|
||||||
("ge", DType::I64) => (strided::ge::I64, DType::U8),
|
("ge", DType::I64) => (strided::ge::I64, DType::U8),
|
||||||
("gt", DType::I64) => (strided::gt::I64, DType::U8),
|
("gt", DType::I64) => (strided::gt::I64, DType::U8),
|
||||||
|
("badd", DType::U32) => (strided::add::U32, self.dtype),
|
||||||
|
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
|
||||||
|
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
|
||||||
|
("bdiv", DType::U32) => (strided::div::U32, self.dtype),
|
||||||
|
("bminimum", DType::U32) => (strided::min::U32, self.dtype),
|
||||||
|
("bmaximum", DType::U32) => (strided::max::U32, self.dtype),
|
||||||
|
("eq", DType::U32) => (strided::eq::U32, DType::U8),
|
||||||
|
("ne", DType::U32) => (strided::ne::U32, DType::U8),
|
||||||
|
("le", DType::U32) => (strided::le::U32, DType::U8),
|
||||||
|
("lt", DType::U32) => (strided::lt::U32, DType::U8),
|
||||||
|
("ge", DType::U32) => (strided::ge::U32, DType::U8),
|
||||||
|
("gt", DType::U32) => (strided::gt::U32, DType::U8),
|
||||||
|
("badd", DType::U8) => (strided::add::U8, self.dtype),
|
||||||
|
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
|
||||||
|
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
|
||||||
|
("bdiv", DType::U8) => (strided::div::U8, self.dtype),
|
||||||
|
("bminimum", DType::U8) => (strided::min::U8, self.dtype),
|
||||||
|
("bmaximum", DType::U8) => (strided::max::U8, self.dtype),
|
||||||
|
("eq", DType::U8) => (strided::eq::U8, DType::U8),
|
||||||
|
("ne", DType::U8) => (strided::ne::U8, DType::U8),
|
||||||
|
("le", DType::U8) => (strided::le::U8, DType::U8),
|
||||||
|
("lt", DType::U8) => (strided::lt::U8, DType::U8),
|
||||||
|
("ge", DType::U8) => (strided::ge::U8, DType::U8),
|
||||||
|
("gt", DType::U8) => (strided::gt::U8, DType::U8),
|
||||||
(name, dtype) => {
|
(name, dtype) => {
|
||||||
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
|
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,9 @@ kernel void FN_NAME_STRIDED( \
|
|||||||
|
|
||||||
#define BINARY_OP(FN, NAME) \
|
#define BINARY_OP(FN, NAME) \
|
||||||
BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
|
BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
|
||||||
BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
|
BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
|
||||||
|
BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
|
||||||
|
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
|
||||||
|
|
||||||
#define INT64_BINARY_OP(NAME, FN) \
|
#define INT64_BINARY_OP(NAME, FN) \
|
||||||
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
|
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
|
||||||
@ -66,7 +68,9 @@ BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
|||||||
|
|
||||||
#define BINARY_OP_OUT(NAME, FN) \
|
#define BINARY_OP_OUT(NAME, FN) \
|
||||||
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
|
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
|
||||||
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
|
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
|
||||||
|
BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
|
||||||
|
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
|
||||||
|
|
||||||
#define INT64_BINARY_OP_OUT(NAME, FN) \
|
#define INT64_BINARY_OP_OUT(NAME, FN) \
|
||||||
BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
|
BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
|
||||||
|
@ -131,6 +131,8 @@ macro_rules! ops{
|
|||||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
|
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
|
||||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
|
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
|
||||||
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
|
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
|
||||||
|
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32"));
|
||||||
|
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8"));
|
||||||
}
|
}
|
||||||
)+
|
)+
|
||||||
pub mod copy {
|
pub mod copy {
|
||||||
@ -153,6 +155,8 @@ macro_rules! ops{
|
|||||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
|
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
|
||||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
|
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
|
||||||
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
|
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
|
||||||
|
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided"));
|
||||||
|
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided"));
|
||||||
}
|
}
|
||||||
)+
|
)+
|
||||||
pub mod copy {
|
pub mod copy {
|
||||||
|
@ -263,21 +263,26 @@ kernel void NAME(
|
|||||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||||
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
||||||
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
||||||
|
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
|
||||||
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
||||||
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
||||||
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
||||||
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
||||||
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
||||||
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
||||||
|
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
|
||||||
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
||||||
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
||||||
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
||||||
|
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
|
||||||
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
||||||
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
||||||
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
||||||
|
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
|
||||||
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
||||||
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
||||||
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
||||||
|
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
|
||||||
|
|
||||||
SOFTMAX(softmax_f32, float)
|
SOFTMAX(softmax_f32, float)
|
||||||
SOFTMAX(softmax_f16, half)
|
SOFTMAX(softmax_f16, half)
|
||||||
|
@ -55,8 +55,8 @@ kernel void FN_NAME( \
|
|||||||
|
|
||||||
WHERE_OP(float, uint8_t, where_u8_f32)
|
WHERE_OP(float, uint8_t, where_u8_f32)
|
||||||
// WHERE_OP(double, uint8_t, where_u8_f64)
|
// WHERE_OP(double, uint8_t, where_u8_f64)
|
||||||
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
||||||
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 220
|
#if __METAL_VERSION__ >= 220
|
||||||
WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
||||||
|
Reference in New Issue
Block a user