mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
update dtypes checks for several metal operations (#2010)
This commit is contained in:
@ -443,42 +443,60 @@ impl BackendStorage for MetalStorage {
|
|||||||
use candle_metal_kernels::unary::contiguous;
|
use candle_metal_kernels::unary::contiguous;
|
||||||
|
|
||||||
let kernel_name = match (B::KERNEL, dtype) {
|
let kernel_name = match (B::KERNEL, dtype) {
|
||||||
("ucos", DType::F32) => contiguous::cos::FLOAT,
|
|
||||||
("usin", DType::F32) => contiguous::sin::FLOAT,
|
|
||||||
("usqr", DType::F32) => contiguous::sqr::FLOAT,
|
|
||||||
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
|
|
||||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
|
||||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
|
||||||
("ulog", DType::F32) => contiguous::log::FLOAT,
|
|
||||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
|
||||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
|
||||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
|
||||||
("usilu", DType::F32) => contiguous::silu::FLOAT,
|
|
||||||
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
|
||||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
|
||||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
|
||||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
|
||||||
("urecip", DType::F32) => contiguous::recip::FLOAT,
|
|
||||||
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
|
||||||
("urelu", DType::F32) => contiguous::relu::FLOAT,
|
|
||||||
("ucos", DType::F16) => contiguous::cos::HALF,
|
|
||||||
("usin", DType::F16) => contiguous::sin::HALF,
|
|
||||||
("usqr", DType::F16) => contiguous::sqr::HALF,
|
|
||||||
("usqrt", DType::F16) => contiguous::sqrt::HALF,
|
|
||||||
("uneg", DType::F16) => contiguous::neg::HALF,
|
|
||||||
("uexp", DType::F16) => contiguous::exp::HALF,
|
|
||||||
("ulog", DType::F16) => contiguous::log::HALF,
|
|
||||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
|
||||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
|
||||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
|
||||||
("usilu", DType::F16) => contiguous::silu::HALF,
|
|
||||||
("uabs", DType::F16) => contiguous::abs::HALF,
|
("uabs", DType::F16) => contiguous::abs::HALF,
|
||||||
|
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
||||||
|
("uabs", DType::BF16) => contiguous::abs::BFLOAT,
|
||||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||||
|
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||||
|
("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
|
||||||
|
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||||
|
("ucos", DType::F32) => contiguous::cos::FLOAT,
|
||||||
|
("ucos", DType::BF16) => contiguous::cos::BFLOAT,
|
||||||
|
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||||
|
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||||
|
("uerf", DType::BF16) => contiguous::erf::BFLOAT,
|
||||||
|
("uexp", DType::F16) => contiguous::exp::HALF,
|
||||||
|
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||||
|
("uexp", DType::BF16) => contiguous::exp::BFLOAT,
|
||||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||||
("uround", DType::F16) => contiguous::round::HALF,
|
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||||
|
("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
|
||||||
|
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||||
|
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||||
|
("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
|
||||||
|
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||||
|
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||||
|
("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
|
||||||
|
("ulog", DType::F16) => contiguous::log::HALF,
|
||||||
|
("ulog", DType::F32) => contiguous::log::FLOAT,
|
||||||
|
("ulog", DType::BF16) => contiguous::log::BFLOAT,
|
||||||
|
("uneg", DType::F16) => contiguous::neg::HALF,
|
||||||
|
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||||
|
("uneg", DType::BF16) => contiguous::neg::BFLOAT,
|
||||||
("urecip", DType::F16) => contiguous::recip::HALF,
|
("urecip", DType::F16) => contiguous::recip::HALF,
|
||||||
("utanh", DType::F16) => contiguous::tanh::HALF,
|
("urecip", DType::F32) => contiguous::recip::FLOAT,
|
||||||
|
("urecip", DType::BF16) => contiguous::recip::BFLOAT,
|
||||||
("urelu", DType::F16) => contiguous::relu::HALF,
|
("urelu", DType::F16) => contiguous::relu::HALF,
|
||||||
|
("urelu", DType::F32) => contiguous::relu::FLOAT,
|
||||||
|
("urelu", DType::BF16) => contiguous::relu::BFLOAT,
|
||||||
|
("uround", DType::F16) => contiguous::round::HALF,
|
||||||
|
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||||
|
("uround", DType::BF16) => contiguous::round::BFLOAT,
|
||||||
|
("usilu", DType::F16) => contiguous::silu::HALF,
|
||||||
|
("usilu", DType::F32) => contiguous::silu::FLOAT,
|
||||||
|
("usilu", DType::BF16) => contiguous::silu::BFLOAT,
|
||||||
|
("usin", DType::F16) => contiguous::sin::HALF,
|
||||||
|
("usin", DType::F32) => contiguous::sin::FLOAT,
|
||||||
|
("usin", DType::BF16) => contiguous::sin::BFLOAT,
|
||||||
|
("usqr", DType::F16) => contiguous::sqr::HALF,
|
||||||
|
("usqr", DType::F32) => contiguous::sqr::FLOAT,
|
||||||
|
("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
|
||||||
|
("usqrt", DType::F16) => contiguous::sqrt::HALF,
|
||||||
|
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
|
||||||
|
("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
|
||||||
|
("utanh", DType::F16) => contiguous::tanh::HALF,
|
||||||
|
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
||||||
|
("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
|
||||||
(name, dtype) => {
|
(name, dtype) => {
|
||||||
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
|
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
|
||||||
}
|
}
|
||||||
|
@ -60,21 +60,24 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
|
|||||||
BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
|
BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
|
||||||
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
|
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
|
||||||
|
|
||||||
#define INT64_BINARY_OP(NAME, FN) \
|
|
||||||
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
|
|
||||||
|
|
||||||
#define BFLOAT_BINARY_OP(FN, NAME) \
|
|
||||||
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
|
||||||
|
|
||||||
#define BINARY_OP_OUT(NAME, FN) \
|
#define BINARY_OP_OUT(NAME, FN) \
|
||||||
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
|
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
|
||||||
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
|
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
|
||||||
BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
|
BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
|
||||||
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
|
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
|
||||||
|
|
||||||
|
#define INT64_BINARY_OP(NAME, FN) \
|
||||||
|
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
|
||||||
|
|
||||||
#define INT64_BINARY_OP_OUT(NAME, FN) \
|
#define INT64_BINARY_OP_OUT(NAME, FN) \
|
||||||
BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);
|
BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);
|
||||||
|
|
||||||
|
#define BFLOAT_BINARY_OP(FN, NAME) \
|
||||||
|
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
||||||
|
|
||||||
|
#define BFLOAT_BINARY_OP_OUT(NAME, FN) \
|
||||||
|
BINARY(FN, bfloat, uint8_t, NAME##_bf16, NAME##_bf16_strided);
|
||||||
|
|
||||||
BINARY_OP(x + y, add)
|
BINARY_OP(x + y, add)
|
||||||
BINARY_OP(x - y, sub)
|
BINARY_OP(x - y, sub)
|
||||||
BINARY_OP(x * y, mul)
|
BINARY_OP(x * y, mul)
|
||||||
@ -112,4 +115,11 @@ BFLOAT_BINARY_OP(x * y, mul)
|
|||||||
BFLOAT_BINARY_OP(x / y, div)
|
BFLOAT_BINARY_OP(x / y, div)
|
||||||
BFLOAT_BINARY_OP(MIN(x, y), min)
|
BFLOAT_BINARY_OP(MIN(x, y), min)
|
||||||
BFLOAT_BINARY_OP(MAX(x, y), max)
|
BFLOAT_BINARY_OP(MAX(x, y), max)
|
||||||
|
|
||||||
|
BFLOAT_BINARY_OP_OUT(eq, x == y)
|
||||||
|
BFLOAT_BINARY_OP_OUT(ne, x != y)
|
||||||
|
BFLOAT_BINARY_OP_OUT(le, x <= y)
|
||||||
|
BFLOAT_BINARY_OP_OUT(lt, x < y)
|
||||||
|
BFLOAT_BINARY_OP_OUT(ge, x >= y)
|
||||||
|
BFLOAT_BINARY_OP_OUT(gt, x > y)
|
||||||
#endif
|
#endif
|
||||||
|
@ -484,9 +484,13 @@ ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
|||||||
|
|
||||||
#if defined(__HAVE_BFLOAT__)
|
#if defined(__HAVE_BFLOAT__)
|
||||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||||
|
// Strided bf16 sum: the element type must be bfloat (not half) so it matches the other *_bf16 reductions in this block.
REDUCE(x + y, fast_sum_bf16_strided, bfloat, 0)
|
||||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||||
|
REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
|
||||||
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
||||||
|
REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
|
||||||
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
|
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
|
||||||
|
REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
|
||||||
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
||||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||||
SOFTMAX(softmax_bf16, bfloat)
|
SOFTMAX(softmax_bf16, bfloat)
|
||||||
|
Reference in New Issue
Block a user