mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Metal: i64 basic support (#1495)
* Adds basic metal i64 support * metal copy i64
This commit is contained in:
@ -532,6 +532,11 @@ impl BackendStorage for MetalStorage {
|
||||
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
|
||||
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
|
||||
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
|
||||
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64_strided", false, false),
|
||||
(ReduceOp::Min, DType::I64) => ("fast_min_i64_strided", true, false),
|
||||
(ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false),
|
||||
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64_strided", true, true),
|
||||
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64_strided", true, true),
|
||||
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
@ -579,10 +584,13 @@ impl BackendStorage for MetalStorage {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64",
|
||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||
(DType::U8, DType::F32) => "cast_u8_f32",
|
||||
(DType::U8, DType::I64) => "cast_u8_i64",
|
||||
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||
(DType::I64, DType::F32) => "cast_i64_f32",
|
||||
(left, right) => {
|
||||
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
|
||||
}
|
||||
@ -602,10 +610,13 @@ impl BackendStorage for MetalStorage {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32_strided",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8_strided",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64_strided",
|
||||
(DType::U8, DType::U32) => "cast_u8_u32_strided",
|
||||
(DType::U8, DType::F32) => "cast_u8_f32_strided",
|
||||
(DType::U8, DType::I64) => "cast_u8_i64_strided",
|
||||
(DType::F32, DType::F16) => "cast_f32_f16_strided",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32_strided",
|
||||
(DType::I64, DType::F32) => "cast_i64_f32_strided",
|
||||
(left, right) => {
|
||||
crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
|
||||
}
|
||||
@ -767,6 +778,7 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match (self.dtype, t.dtype()) {
|
||||
(DType::U8, DType::F32) => "where_u8_f32",
|
||||
(DType::U8, DType::F16) => "where_u8_f16",
|
||||
(DType::U8, DType::I64) => "where_u8_i64",
|
||||
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_where_cond_strided(
|
||||
@ -1226,6 +1238,7 @@ impl BackendStorage for MetalStorage {
|
||||
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
|
||||
DType::I64 => candle_metal_kernels::unary::strided::copy::I64,
|
||||
DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
|
||||
DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
|
||||
dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
|
||||
@ -1300,6 +1313,16 @@ impl MetalStorage {
|
||||
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
|
||||
("add", DType::I64) => (contiguous::add::I64, self.dtype),
|
||||
("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
|
||||
("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
|
||||
("div", DType::I64) => (contiguous::div::I64, self.dtype),
|
||||
("eq", DType::I64) => (contiguous::eq::I64, DType::U8),
|
||||
("ne", DType::I64) => (contiguous::ne::I64, DType::U8),
|
||||
("le", DType::I64) => (contiguous::le::I64, DType::U8),
|
||||
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
|
||||
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
|
||||
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
|
||||
}
|
||||
@ -1345,6 +1368,18 @@ impl MetalStorage {
|
||||
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
|
||||
("badd", DType::I64) => (strided::add::I64, self.dtype),
|
||||
("bsub", DType::I64) => (strided::sub::I64, self.dtype),
|
||||
("bmul", DType::I64) => (strided::mul::I64, self.dtype),
|
||||
("bdiv", DType::I64) => (strided::div::I64, self.dtype),
|
||||
("bminimum", DType::I64) => (strided::min::I64, self.dtype),
|
||||
("bmaximum", DType::I64) => (strided::max::I64, self.dtype),
|
||||
("eq", DType::I64) => (strided::eq::I64, DType::U8),
|
||||
("ne", DType::I64) => (strided::ne::I64, DType::U8),
|
||||
("le", DType::I64) => (strided::le::I64, DType::U8),
|
||||
("lt", DType::I64) => (strided::lt::I64, DType::U8),
|
||||
("ge", DType::I64) => (strided::ge::I64, DType::U8),
|
||||
("gt", DType::I64) => (strided::gt::I64, DType::U8),
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
|
||||
}
|
||||
|
@ -58,6 +58,9 @@ kernel void FN_NAME_STRIDED( \
|
||||
BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
|
||||
BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
|
||||
|
||||
#define INT64_BINARY_OP(NAME, FN) \
|
||||
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
|
||||
|
||||
#define BFLOAT_BINARY_OP(FN, NAME) \
|
||||
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
||||
|
||||
@ -65,6 +68,8 @@ BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
||||
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
|
||||
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
|
||||
|
||||
#define INT64_BINARY_OP_OUT(NAME, FN) \
|
||||
BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
|
||||
|
||||
BINARY_OP(x + y, add)
|
||||
BINARY_OP(x - y, sub)
|
||||
@ -80,6 +85,22 @@ BINARY_OP_OUT(lt, x < y)
|
||||
BINARY_OP_OUT(ge, x >= y)
|
||||
BINARY_OP_OUT(gt, x > y)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
INT64_BINARY_OP(add, x + y)
|
||||
INT64_BINARY_OP(sub, x - y)
|
||||
INT64_BINARY_OP(mul, x * y)
|
||||
INT64_BINARY_OP(div, x / y)
|
||||
INT64_BINARY_OP(min, MIN(x, y))
|
||||
INT64_BINARY_OP(max, MAX(x, y))
|
||||
|
||||
INT64_BINARY_OP_OUT(eq, x == y)
|
||||
INT64_BINARY_OP_OUT(ne, x != y)
|
||||
INT64_BINARY_OP_OUT(le, x <= y)
|
||||
INT64_BINARY_OP_OUT(lt, x < y)
|
||||
INT64_BINARY_OP_OUT(ge, x >= y)
|
||||
INT64_BINARY_OP_OUT(gt, x > y)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
BFLOAT_BINARY_OP(x + y, add)
|
||||
BFLOAT_BINARY_OP(x - y, sub)
|
||||
|
@ -52,5 +52,11 @@ CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
|
||||
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
|
||||
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
|
||||
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
|
||||
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#endif
|
||||
|
@ -130,6 +130,7 @@ macro_rules! ops{
|
||||
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
|
||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
|
||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
|
||||
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
|
||||
}
|
||||
)+
|
||||
pub mod copy {
|
||||
@ -137,6 +138,7 @@ macro_rules! ops{
|
||||
pub const FLOAT: Kernel = Kernel("copy_f32");
|
||||
pub const HALF: Kernel = Kernel("copy_f16");
|
||||
pub const BFLOAT: Kernel = Kernel("copy_bf16");
|
||||
pub const I64: Kernel = Kernel("copy_i64");
|
||||
pub const U32: Kernel = Kernel("copy_u32");
|
||||
pub const U8: Kernel = Kernel("copy_u8");
|
||||
}
|
||||
@ -150,6 +152,7 @@ macro_rules! ops{
|
||||
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
|
||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
|
||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
|
||||
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
|
||||
}
|
||||
)+
|
||||
pub mod copy {
|
||||
@ -157,6 +160,7 @@ macro_rules! ops{
|
||||
pub const FLOAT: Kernel = Kernel("copy_f32_strided");
|
||||
pub const HALF: Kernel = Kernel("copy_f16_strided");
|
||||
pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
|
||||
pub const I64: Kernel = Kernel("copy_i64_strided");
|
||||
pub const U32: Kernel = Kernel("copy_u32_strided");
|
||||
pub const U8: Kernel = Kernel("copy_u8_strided");
|
||||
}
|
||||
|
@ -281,6 +281,15 @@ ARGMAX(fast_argmax_u32_strided, uint, 0)
|
||||
|
||||
SOFTMAX(softmax_f32, float)
|
||||
SOFTMAX(softmax_f16, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
|
||||
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
|
||||
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||
|
@ -57,4 +57,7 @@ WHERE_OP(float, uint8_t, where_u8_f32)
|
||||
// WHERE_OP(double, uint8_t, where_u8_f64)
|
||||
// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
||||
// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
||||
// WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
||||
#endif
|
||||
|
@ -115,6 +115,10 @@ UNARY(id, half, copy_f16, copy_f16_strided)
|
||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
UNARY(id, int64_t, copy_i64, copy_i64_strided)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
BFLOAT_UNARY_OP(cos)
|
||||
BFLOAT_UNARY_OP(sin)
|
||||
|
Reference in New Issue
Block a user