Metal backend: add bf16 entries for contiguous unary kernels, bf16
comparison kernels, and strided bf16 reduce kernels.

Summary of the three file diffs below:
  * candle-core/src/metal_backend/mod.rs: the (kernel, dtype) -> kernel-name
    match for contiguous unary ops is regrouped by op name and gains a
    DType::BF16 arm (contiguous::<op>::BFLOAT) for every op.
  * candle-metal-kernels/src/binary.metal: the INT64/BFLOAT macro definitions
    are moved below BINARY_OP_OUT, a BFLOAT_BINARY_OP_OUT macro is added, and
    it is invoked for eq/ne/le/lt/ge/gt inside the existing #if block.
  * candle-metal-kernels/src/reduce.metal: *_bf16_strided REDUCE
    instantiations are added for sum/mul/max/min.

Review fix: fast_sum_bf16_strided was instantiated with element type `half`
while every other bf16 reduce kernel in that section uses `bfloat`; it now
uses `bfloat` as well.

NOTE(review): leading indentation of the Rust context/changed lines was not
recoverable from the copy under review and is reconstructed here; if the
context does not match, apply with whitespace-insensitive matching. The
abbreviated post-image hash on the reduce.metal index line was computed for
the pre-fix patch and is left untouched.

diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index deb7a401..fa6973b4 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -443,42 +443,60 @@ impl BackendStorage for MetalStorage {
             use candle_metal_kernels::unary::contiguous;
 
             let kernel_name = match (B::KERNEL, dtype) {
-                ("ucos", DType::F32) => contiguous::cos::FLOAT,
-                ("usin", DType::F32) => contiguous::sin::FLOAT,
-                ("usqr", DType::F32) => contiguous::sqr::FLOAT,
-                ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
-                ("uneg", DType::F32) => contiguous::neg::FLOAT,
-                ("uexp", DType::F32) => contiguous::exp::FLOAT,
-                ("ulog", DType::F32) => contiguous::log::FLOAT,
-                ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
-                ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
-                ("uerf", DType::F32) => contiguous::erf::FLOAT,
-                ("usilu", DType::F32) => contiguous::silu::FLOAT,
-                ("uabs", DType::F32) => contiguous::abs::FLOAT,
-                ("uceil", DType::F32) => contiguous::ceil::FLOAT,
-                ("ufloor", DType::F32) => contiguous::floor::FLOAT,
-                ("uround", DType::F32) => contiguous::round::FLOAT,
-                ("urecip", DType::F32) => contiguous::recip::FLOAT,
-                ("utanh", DType::F32) => contiguous::tanh::FLOAT,
-                ("urelu", DType::F32) => contiguous::relu::FLOAT,
-                ("ucos", DType::F16) => contiguous::cos::HALF,
-                ("usin", DType::F16) => contiguous::sin::HALF,
-                ("usqr", DType::F16) => contiguous::sqr::HALF,
-                ("usqrt", DType::F16) => contiguous::sqrt::HALF,
-                ("uneg", DType::F16) => contiguous::neg::HALF,
-                ("uexp", DType::F16) => contiguous::exp::HALF,
-                ("ulog", DType::F16) => contiguous::log::HALF,
-                ("ugelu", DType::F16) => contiguous::gelu::HALF,
-                ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
-                ("uerf", DType::F16) => contiguous::erf::HALF,
-                ("usilu", DType::F16) => contiguous::silu::HALF,
                 ("uabs", DType::F16) => contiguous::abs::HALF,
+                ("uabs", DType::F32) => contiguous::abs::FLOAT,
+                ("uabs", DType::BF16) => contiguous::abs::BFLOAT,
                 ("uceil", DType::F16) => contiguous::ceil::HALF,
+                ("uceil", DType::F32) => contiguous::ceil::FLOAT,
+                ("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
+                ("ucos", DType::F16) => contiguous::cos::HALF,
+                ("ucos", DType::F32) => contiguous::cos::FLOAT,
+                ("ucos", DType::BF16) => contiguous::cos::BFLOAT,
+                ("uerf", DType::F16) => contiguous::erf::HALF,
+                ("uerf", DType::F32) => contiguous::erf::FLOAT,
+                ("uerf", DType::BF16) => contiguous::erf::BFLOAT,
+                ("uexp", DType::F16) => contiguous::exp::HALF,
+                ("uexp", DType::F32) => contiguous::exp::FLOAT,
+                ("uexp", DType::BF16) => contiguous::exp::BFLOAT,
                 ("ufloor", DType::F16) => contiguous::floor::HALF,
-                ("uround", DType::F16) => contiguous::round::HALF,
+                ("ufloor", DType::F32) => contiguous::floor::FLOAT,
+                ("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
+                ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
+                ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
+                ("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
+                ("ugelu", DType::F16) => contiguous::gelu::HALF,
+                ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
+                ("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
+                ("ulog", DType::F16) => contiguous::log::HALF,
+                ("ulog", DType::F32) => contiguous::log::FLOAT,
+                ("ulog", DType::BF16) => contiguous::log::BFLOAT,
+                ("uneg", DType::F16) => contiguous::neg::HALF,
+                ("uneg", DType::F32) => contiguous::neg::FLOAT,
+                ("uneg", DType::BF16) => contiguous::neg::BFLOAT,
                 ("urecip", DType::F16) => contiguous::recip::HALF,
-                ("utanh", DType::F16) => contiguous::tanh::HALF,
+                ("urecip", DType::F32) => contiguous::recip::FLOAT,
+                ("urecip", DType::BF16) => contiguous::recip::BFLOAT,
                 ("urelu", DType::F16) => contiguous::relu::HALF,
+                ("urelu", DType::F32) => contiguous::relu::FLOAT,
+                ("urelu", DType::BF16) => contiguous::relu::BFLOAT,
+                ("uround", DType::F16) => contiguous::round::HALF,
+                ("uround", DType::F32) => contiguous::round::FLOAT,
+                ("uround", DType::BF16) => contiguous::round::BFLOAT,
+                ("usilu", DType::F16) => contiguous::silu::HALF,
+                ("usilu", DType::F32) => contiguous::silu::FLOAT,
+                ("usilu", DType::BF16) => contiguous::silu::BFLOAT,
+                ("usin", DType::F16) => contiguous::sin::HALF,
+                ("usin", DType::F32) => contiguous::sin::FLOAT,
+                ("usin", DType::BF16) => contiguous::sin::BFLOAT,
+                ("usqr", DType::F16) => contiguous::sqr::HALF,
+                ("usqr", DType::F32) => contiguous::sqr::FLOAT,
+                ("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
+                ("usqrt", DType::F16) => contiguous::sqrt::HALF,
+                ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
+                ("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
+                ("utanh", DType::F16) => contiguous::tanh::HALF,
+                ("utanh", DType::F32) => contiguous::tanh::FLOAT,
+                ("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
                 (name, dtype) => {
                     crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
                 }
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
index ae11286a..e83498e4 100644
--- a/candle-metal-kernels/src/binary.metal
+++ b/candle-metal-kernels/src/binary.metal
@@ -60,21 +60,24 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
 BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
 BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
 
-#define INT64_BINARY_OP(NAME, FN) \
-BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
-
-#define BFLOAT_BINARY_OP(FN, NAME) \
-BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
-
 #define BINARY_OP_OUT(NAME, FN) \
 BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
 BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
 BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
 BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
 
+#define INT64_BINARY_OP(NAME, FN) \
+BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
+
 #define INT64_BINARY_OP_OUT(NAME, FN) \
 BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);
 
+#define BFLOAT_BINARY_OP(FN, NAME) \
+BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
+
+#define BFLOAT_BINARY_OP_OUT(NAME, FN) \
+BINARY(FN, bfloat, uint8_t, NAME##_bf16, NAME##_bf16_strided);
+
 BINARY_OP(x + y, add)
 BINARY_OP(x - y, sub)
 BINARY_OP(x * y, mul)
@@ -112,4 +115,11 @@ BFLOAT_BINARY_OP(x * y, mul)
 BFLOAT_BINARY_OP(x / y, div)
 BFLOAT_BINARY_OP(MIN(x, y), min)
 BFLOAT_BINARY_OP(MAX(x, y), max)
+
+BFLOAT_BINARY_OP_OUT(eq, x == y)
+BFLOAT_BINARY_OP_OUT(ne, x != y)
+BFLOAT_BINARY_OP_OUT(le, x <= y)
+BFLOAT_BINARY_OP_OUT(lt, x < y)
+BFLOAT_BINARY_OP_OUT(ge, x >= y)
+BFLOAT_BINARY_OP_OUT(gt, x > y)
 #endif
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 561d1744..acb69299 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -484,9 +484,13 @@ ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
 
 #if defined(__HAVE_BFLOAT__)
 REDUCE(x + y, fast_sum_bf16, bfloat, 0)
+REDUCE(x + y, fast_sum_bf16_strided, bfloat, 0)
 REDUCE(x * y, fast_mul_bf16, bfloat, 1)
+REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
 REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
+REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
 REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
+REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
 ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
 ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
 SOFTMAX(softmax_bf16, bfloat)