diff --git a/candle-metal-kernels/compile.sh b/candle-metal-kernels/compile.sh new file mode 100644 index 00000000..04d1c4d2 --- /dev/null +++ b/candle-metal-kernels/compile.sh @@ -0,0 +1,2 @@ +xcrun metal -c src/gemm/kernels/steel_gemm.metal -I src/ +xcrun metallib steel_gemm.air -o src/gemm/steel_gemm.metallib diff --git a/candle-metal-kernels/src/gemm/bf16.h b/candle-metal-kernels/src/gemm/bf16.h new file mode 100644 index 00000000..cc05998f --- /dev/null +++ b/candle-metal-kernels/src/gemm/bf16.h @@ -0,0 +1,317 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +using namespace metal; + +#if defined(__HAVE_BFLOAT__) + +typedef bfloat bfloat16_t; + +#else + +///////////////////////////////////////////////////////////////////////////// +// Helpers +///////////////////////////////////////////////////////////////////////////// + +constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) { + // Check for nan + if ((as_type(x) & ~_fp_encoding_traits::sign_mask) > + _fp_encoding_traits::inf_mask) { + return uint16_t(as_type(0x7FC0)); + } + // Take bits + uint32_t float_bits = as_type(x); + + // Round to nearest even + float_bits += ((float_bits >> 16) & 1) + as_type(0x7FFF); + + // Take upper 16 bits + return float_bits >> 16; +} + +constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) { + // Upper 16 bits are the data and lower 16 bits are 0s + return as_type((uint32_t)x << 16); +} + +struct _MLX_BFloat16; + +template +static constexpr constant bool can_convert_to_bfloat = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_bfloat = + !is_same_v && is_convertible_v; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat struct +///////////////////////////////////////////////////////////////////////////// + +struct _MLX_BFloat16 { + ///////////////////////////////////////////////////////////////////////////// + // Constructors + uint16_t bits_; + _MLX_BFloat16() thread = default; + _MLX_BFloat16() threadgroup = default; + _MLX_BFloat16() device = default; + _MLX_BFloat16() constant = default; + + struct bits_to_bfloat_struct {}; + static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() { + return bits_to_bfloat_struct(); + } + constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct) + : bits_(bits) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions to bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) thread + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) device + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) constant + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions from bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const thread { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() 
const threadgroup { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const device { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const constant { + return static_cast(bfloat_bits_to_float(bits_)); + } +}; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat operators +///////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////// +// Unary ops +constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { + return -static_cast(x); +} + +///////////////////////////////////////////////////////////////////////////// +// Binary operators +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +///////////////////////////////////////////////////////////////////////////// +// Arithmetic Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +///////////////////////////////////////////////////////////////////////////// +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop +#undef bfloat_binop_base +#undef bfloat_binop_helper +#undef bfloat_binop + +///////////////////////////////////////////////////////////////////////////// +// Inplace Operators +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr 
METAL_FUNC addr_space itype& __operator__( \ + addr_space itype& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); + +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ + bfloat_inplace_op_addr_space_helper(/, operator/=, itype); + +bfloat_inplace_op(float); +bfloat_inplace_op(half); +bfloat_inplace_op(int16_t); +bfloat_inplace_op(int32_t); +bfloat_inplace_op(int64_t); +bfloat_inplace_op(uint16_t); +bfloat_inplace_op(uint32_t); +bfloat_inplace_op(uint64_t); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper +#undef bfloat_inplace_op + +#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, threadgroup); + +bfloat_inplace_op_addr_space_helper(+, operator+=); +bfloat_inplace_op_addr_space_helper(-, operator-=); +bfloat_inplace_op_addr_space_helper(*, operator*=); +bfloat_inplace_op_addr_space_helper(/, operator/=); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper + +///////////////////////////////////////////////////////////////////////////// +// Bfloat typedef +///////////////////////////////////////////////////////////////////////////// + +typedef struct _MLX_BFloat16 bfloat16_t; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat numeric limits +///////////////////////////////////////////////////////////////////////////// + +#pragma METAL internals : enable + +namespace metal { + +template <> +struct _numeric_limits_impl : _fp_numeric_limits_impl_base { + static constexpr constant int digits = 8; + static constexpr constant int digits10 = 2; + static constexpr constant int max_digits10 = 4; + static constexpr constant int radix = 2; + static constexpr constant int min_exponent = -125; + static constexpr constant int min_exponent10 = -37; + static constexpr constant int max_exponent = 128; + static constexpr constant int max_exponent10 = 38; + + static constexpr bfloat16_t min() { + return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t lowest() { + return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t max() { + return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t epsilon() { + return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t round_error() { + return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t infinity() { + return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t quiet_NaN() { + return 
_MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t signaling_NaN() { + return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t denorm_min() { + return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat()); + } +}; + +METAL_FUNC bool isnan(_MLX_BFloat16 x) { + return x != x; +} + +} // namespace metal + +#pragma METAL internals : disable + +#endif // defined(__HAVE_BFLOAT__) + +#include "gemm/bf16_math.h" diff --git a/candle-metal-kernels/src/gemm/bf16_math.h b/candle-metal-kernels/src/gemm/bf16_math.h new file mode 100644 index 00000000..e6133346 --- /dev/null +++ b/candle-metal-kernels/src/gemm/bf16_math.h @@ -0,0 +1,394 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "gemm/bf16.h" + +/////////////////////////////////////////////////////////////////////////////// +// Metal math for bfloat16 +/////////////////////////////////////////////////////////////////////////////// + +/* + +Following the Metal Shading Language Specification (Metal 3.1) + +"bfloat is an extended itypeing point type that only allows implicit conversion + to a type of greater itypeing point rank. While bfloat can be implicitly + converted to itype, it cannot be implicitly converted to half, and neither + itype nor half can be implicitly converted to bfloat." + +Further, as far as I can tell, the stdlib math/simd functions are not defined +for bfloat and calling with an argument of type bfloat will result in that +argument getting implicitly converted to itype which then returns an output +that is (likely) a itype which cannot be implicitly converted into a bfloat + +This leads to situations where +bfloat a = 5.0bf; +bfloat b = metal::abs(a); // this will throw an error since abs return itype +bfloat c = static_cast(metal::abs(a)); // this is fine + +For the moment, I will be adding overloaded instantiations of the math +functions to accordingly automatically handle the casting + +*/ + +#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \ + \ + METAL_FUNC otype abs(itype x) { \ + return static_cast(__metal_fabs(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype acos(itype x) { \ + return static_cast(__metal_acos(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype acosh(itype x) { \ + return static_cast(__metal_acosh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype asin(itype x) { \ + return static_cast(__metal_asin(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype asinh(itype x) { \ + return static_cast(__metal_asinh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype atan(itype y_over_x) { \ + return static_cast( \ + __metal_atan(static_cast(y_over_x), mfast)); \ + } \ + METAL_FUNC otype atan2(itype y, itype x) { \ + return static_cast( \ + __metal_atan2(static_cast(y), static_cast(x), mfast)); \ + } \ + METAL_FUNC otype atanh(itype x) { \ + return static_cast(__metal_atanh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype ceil(itype x) { \ + return static_cast(__metal_ceil(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cos(itype x) { \ + return static_cast(__metal_cos(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cosh(itype x) { \ + return static_cast(__metal_cosh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cospi(itype x) { \ + return static_cast(__metal_cospi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype divide(itype x, itype y) { \ + return static_cast( \ + __metal_divide(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype exp(itype x) { \ + return 
static_cast(__metal_exp(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype exp10(itype x) { \ + return static_cast(__metal_exp10(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype exp2(itype x) { \ + return static_cast(__metal_exp2(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fabs(itype x) { \ + return static_cast(__metal_fabs(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fdim(itype x, itype y) { \ + ctype t = static_cast(x - y); \ + return static_cast(select(t, ctype(0), t < ctype(0) || x == y)); \ + } \ + METAL_FUNC otype floor(itype x) { \ + return static_cast(__metal_floor(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fma(itype x, itype y, itype z) { \ + return static_cast(__metal_fma( \ + static_cast(x), static_cast(y), static_cast(z))); \ + } \ + METAL_FUNC otype fmax(itype x, itype y) { \ + return static_cast( \ + __metal_fmax(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fmax3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmax3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmedian3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmin(itype x, itype y) { \ + return static_cast( \ + __metal_fmin(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fmin3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmin3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmod(itype x, itype y) { \ + return static_cast( \ + __metal_fmod(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fract(itype x) { \ + return static_cast(__metal_fract(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype frexp(itype x, thread int& exp) { \ + return static_cast(__metal_frexp(static_cast(x), &exp)); \ + } \ + METAL_FUNC otype ldexp(itype x, int k) { \ + return static_cast(__metal_ldexp(static_cast(x), k, mfast)); \ + } \ + METAL_FUNC otype log(itype x) { \ + return static_cast(__metal_log(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype log10(itype x) { \ + return static_cast(__metal_log10(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype log2(itype x) { \ + return static_cast(__metal_log2(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype max(itype x, itype y) { \ + return static_cast( \ + __metal_fmax(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype max3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmax3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype median3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmedian3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype min(itype x, itype y) { \ + return static_cast( \ + __metal_fmin(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype min3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmin3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype nextafter(itype x, itype y) { \ + return static_cast( \ + __metal_nextafter(static_cast(x), static_cast(y))); \ + } \ + METAL_FUNC otype pow(itype x, itype y) { \ + return static_cast( \ + __metal_pow(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype powr(itype x, itype y) { \ + return static_cast( \ + __metal_powr(static_cast(x), static_cast(y), mfast)); \ + } 
\ + METAL_FUNC otype rint(itype x) { \ + return static_cast(__metal_rint(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype round(itype x) { \ + return static_cast(__metal_round(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype rsqrt(itype x) { \ + return static_cast(__metal_rsqrt(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sin(itype x) { \ + return static_cast(__metal_sin(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sinh(itype x) { \ + return static_cast(__metal_sinh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sinpi(itype x) { \ + return static_cast(__metal_sinpi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sqrt(itype x) { \ + return static_cast(__metal_sqrt(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tan(itype x) { \ + return static_cast(__metal_tan(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tanh(itype x) { \ + return static_cast(__metal_tanh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tanpi(itype x) { \ + return static_cast(__metal_tanpi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype trunc(itype x) { \ + return static_cast(__metal_trunc(static_cast(x), mfast)); \ + } + +namespace metal { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_MAYBE_FAST_MATH__); + +namespace fast { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_FAST_MATH__); + +} // namespace fast + +namespace precise { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_PRECISE_MATH__); + +} // namespace precise + +} // namespace metal + +/////////////////////////////////////////////////////////////////////////////// +// Metal simd for bfloat16 +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_metal_simd_comm_funcs( \ + itype, otype, ctype, itype_to_ctype, ctype_to_otype) \ + \ + METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \ + return ctype_to_otype( \ + __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \ + } \ + \ + METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \ + return ctype_to_otype( \ + __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_down( \ + itype data, itype filling_data, ushort delta, ushort modulo) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ + itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_down( \ + itype data, itype filling_data, ushort delta) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ + itype_to_ctype(data), \ + itype_to_ctype(filling_data), \ + delta, \ + __metal_get_simdgroup_size(ushort()))); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_up( \ + itype data, itype filling_data, ushort delta, ushort modulo) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ + itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_up( \ + itype data, itype filling_data, ushort delta) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ + itype_to_ctype(data), \ + itype_to_ctype(filling_data), \ + delta, \ + __metal_get_simdgroup_size(ushort()))); \ + } \ + \ + METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort 
delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \ + } + +#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \ + \ + METAL_FUNC otype simd_max(itype data) { \ + return static_cast(__metal_simd_max(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_min(itype data) { \ + return static_cast(__metal_simd_min(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \ + return static_cast( \ + __metal_simd_prefix_exclusive_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \ + return static_cast( \ + __metal_simd_prefix_exclusive_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \ + return static_cast( \ + __metal_simd_prefix_inclusive_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \ + return static_cast( \ + __metal_simd_prefix_inclusive_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_product(itype data) { \ + return static_cast(__metal_simd_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_sum(itype data) { \ + return static_cast(__metal_simd_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_xor(itype data) { \ + return static_cast(__metal_simd_xor(static_cast(data))); \ + } + +#if defined(__HAVE_BFLOAT__) + +#define bfloat16_to_uint16(x) as_type(x) +#define uint16_to_bfloat16(x) as_type(x) + +#else + +#define bfloat16_to_uint16(x) x.bits_ +#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat()) + +#endif + +namespace metal { + +instantiate_metal_simd_comm_funcs( + bfloat16_t, + bfloat16_t, + uint16_t, + bfloat16_to_uint16, + uint16_to_bfloat16); +instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float); + +} // namespace metal diff --git a/candle-metal-kernels/src/gemm/complex.h b/candle-metal-kernels/src/gemm/complex.h new file mode 100644 index 00000000..9cb27c5a --- /dev/null +++ b/candle-metal-kernels/src/gemm/complex.h @@ -0,0 +1,131 @@ +// Copyright © 2023 Apple Inc. 
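For reference, the conversion helpers in bf16.h above pack a float into bfloat16 by rounding to nearest even and keeping the upper 16 bits. Below is a minimal host-side sketch in standard C++ of the same scheme; the helper names are illustrative and not part of this diff.

```cpp
// Host-side sketch of the bfloat16 scheme used in bf16.h: drop the low 16 bits
// of a float with round-to-nearest-even, keep the high 16. Helper names here
// (float_to_bfloat_bits_ref, bfloat_bits_to_float_ref) are illustrative only.
#include <cstdint>
#include <cstring>
#include <cstdio>

static uint16_t float_to_bfloat_bits_ref(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  // NaN inputs map to the canonical bf16 quiet NaN (0x7FC0).
  if ((bits & 0x7FFFFFFFu) > 0x7F800000u) return 0x7FC0;
  // Round to nearest even: bias by 0x7FFF plus the lowest surviving bit.
  bits += ((bits >> 16) & 1u) + 0x7FFFu;
  return static_cast<uint16_t>(bits >> 16);
}

static float bfloat_bits_to_float_ref(uint16_t x) {
  uint32_t bits = static_cast<uint32_t>(x) << 16;  // low 16 bits stay zero
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  float v = 3.14159f;
  uint16_t b = float_to_bfloat_bits_ref(v);
  printf("%f -> 0x%04X -> %f\n", v, (unsigned)b, bfloat_bits_to_float_ref(b));
  return 0;
}
```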
+ +#pragma once + +#include + +using namespace metal; + +struct complex64_t; + +template +static constexpr constant bool can_convert_to_complex64 = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_complex64 = + !is_same_v && + (is_convertible_v || is_convertible_v); + +struct complex64_t { + float real; + float imag; + + // Constructors + constexpr complex64_t(float real, float imag) : real(real), imag(imag){}; + + // Conversions to complex64_t + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) thread : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) threadgroup : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) device : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) constant : real(x), imag(0) {} + + // Conversions from complex64_t + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const thread { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const threadgroup { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const device { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const constant { + return static_cast(real); + } +}; + +constexpr complex64_t operator-(complex64_t x) { + return {-x.real, -x.imag}; +} + +constexpr bool operator>=(complex64_t a, complex64_t b) { + return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag); +} + +constexpr bool operator>(complex64_t a, complex64_t b) { + return (a.real > b.real) || (a.real == b.real && a.imag > b.imag); +} + +constexpr bool operator<=(complex64_t a, complex64_t b) { + return operator>=(b, a); +} + +constexpr bool operator<(complex64_t a, complex64_t b) { + return operator>(b, a); +} + +constexpr bool operator==(complex64_t a, complex64_t b) { + return a.real == b.real && a.imag == b.imag; +} + +constexpr complex64_t operator+(complex64_t a, complex64_t b) { + return {a.real + b.real, a.imag + b.imag}; +} + +constexpr complex64_t operator-(complex64_t a, complex64_t b) { + return {a.real - b.real, a.imag - b.imag}; +} + +constexpr complex64_t operator*(complex64_t a, complex64_t b) { + return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; +} + +constexpr complex64_t operator/(complex64_t a, complex64_t b) { + auto denom = b.real * b.real + b.imag * b.imag; + auto x = a.real * b.real + a.imag * b.imag; + auto y = a.imag * b.real - a.real * b.imag; + return {x / denom, y / denom}; +} + +constexpr complex64_t operator%(complex64_t a, complex64_t b) { + auto real = a.real - (b.real * static_cast(a.real / b.real)); + auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag)); + if (real != 0 && (real < 0 != b.real < 0)) { + real += b.real; + } + if (imag != 0 && (imag < 0 != b.imag < 0)) { + imag += b.imag; + } + return {real, imag}; +} diff --git a/candle-metal-kernels/src/gemm/gemm.h b/candle-metal-kernels/src/gemm/gemm.h new file mode 100644 index 00000000..0c31f781 --- /dev/null +++ b/candle-metal-kernels/src/gemm/gemm.h @@ -0,0 +1,292 @@ +// Copyright © 2024 Apple Inc. 
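For reference, complex64_t above implements multiplication componentwise and division as a * conj(b) / |b|^2. Below is a standalone C++ check of those two formulas; the type and function names are illustrative and not taken from the diff.

```cpp
// Standalone check of the complex64_t multiply/divide formulas from complex.h.
// Names (c64, c64_mul, c64_div) are illustrative only.
#include <cstdio>

struct c64 { float real, imag; };

static c64 c64_mul(c64 a, c64 b) {
  return {a.real * b.real - a.imag * b.imag,
          a.real * b.imag + a.imag * b.real};
}

static c64 c64_div(c64 a, c64 b) {
  // a / b = a * conj(b) / |b|^2, as in complex.h's operator/.
  float denom = b.real * b.real + b.imag * b.imag;
  return {(a.real * b.real + a.imag * b.imag) / denom,
          (a.imag * b.real - a.real * b.imag) / denom};
}

int main() {
  c64 a{1.0f, 2.0f}, b{3.0f, -4.0f};
  c64 q = c64_div(c64_mul(a, b), b);  // round-trips back to a (up to rounding)
  printf("%f %f\n", q.real, q.imag);
  return 0;
}
```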
+ +#pragma once + +#include "gemm/loader.h" +#include "gemm/mma.h" +#include "gemm/transforms.h" +#include "utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel class +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + STEEL_CONST short tgp_padding_a = 16 / sizeof(T); + STEEL_CONST short tgp_padding_b = 16 / sizeof(T); + STEEL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + STEEL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + STEEL_CONST short tgp_size = WM * WN * 32; + + using loader_a_t = BlockLoader< + T, + transpose_a ? BK : BM, + transpose_a ? BM : BK, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + !transpose_a, + tgp_size>; + using loader_b_t = BlockLoader< + T, + transpose_b ? BN : BK, + transpose_b ? BK : BN, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + transpose_b, + tgp_size>; + using mma_t = BlockMMA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + thread mma_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + thread const short& lbk, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned_) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* C [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + + A += transpose_a ? c_row : c_row * params->lda; + B += transpose_b ? c_col * params->ldb : c_col; + C += c_row * params->ldc + c_col; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(C, params->ldc); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result(C, params->ldc); + return; + + } else if (tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + return; + + } else if (tgp_bm == BM) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + return; + + } else { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + return; + } + } + } +}; + +} // namespace steel +} // namespace mlx diff --git a/candle-metal-kernels/src/gemm/host.h b/candle-metal-kernels/src/gemm/host.h new file mode 100644 index 00000000..3b2d550b --- /dev/null +++ b/candle-metal-kernels/src/gemm/host.h @@ -0,0 +1,5 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "params.h" diff --git a/candle-metal-kernels/src/gemm/kernels/steel_gemm.metal b/candle-metal-kernels/src/gemm/kernels/steel_gemm.metal new file mode 100644 index 00000000..5caf2726 --- /dev/null +++ b/candle-metal-kernels/src/gemm/kernels/steel_gemm.metal @@ -0,0 +1,89 @@ +// Copyright © 2024 Apple Inc. 
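For reference, GEMMKernel::run above assigns each threadgroup one (BM, BN) tile of C and swizzles the threadgroup index so that neighbouring groups reuse the same rows of A and columns of B. Below is a small C++ sketch of that index remapping; the tile counts, swizzle_log value, and the assumption that the grid's x-dimension carries the swizzled index are illustrative, since the actual launch grid is set up by host code outside this diff.

```cpp
// Sketch of the threadgroup-index swizzle in GEMMKernel::run.
// tiles_m, tiles_n, and swizzle_log are example values.
#include <cstdio>

int main() {
  const int swizzle_log = 1;            // pack 2^1 tile rows per swizzled column
  const int tiles_m = 4, tiles_n = 4;   // C split into tiles_m x tiles_n blocks

  for (int ty = 0; ty < (tiles_m >> swizzle_log); ++ty) {
    for (int tx = 0; tx < (tiles_n << swizzle_log); ++tx) {
      // Same remapping as in the kernel body:
      int tid_y = (ty << swizzle_log) + (tx & ((1 << swizzle_log) - 1));
      int tid_x = tx >> swizzle_log;
      if (tid_x >= tiles_n || tid_y >= tiles_m) continue;  // kernel returns early
      printf("group (x=%d, y=%d) -> C tile (row %d, col %d)\n", tx, ty, tid_y, tid_x);
    }
  }
  return 0;
}
```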
+ +#include "gemm/bf16.h" +#include "gemm/gemm.h" + +using namespace metal; +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm( + const device T *A [[buffer(0)]], + const device T *B [[buffer(1)]], + device T *C [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + using gemm_kernel = GEMMKernel; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Adjust for batch + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + C += params->batch_stride_c * tid.z; + + gemm_kernel::run( + A, B, C, + params, + As, Bs, + simd_lane_id, simd_group_id, tid, lid + ); +} + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel initializations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \ + [[kernel]] void gemm( \ + const device itype *A [[buffer(0)]], \ + const device itype *B [[buffer(1)]], \ + device itype *C [[buffer(2)]], \ + const constant GEMMParams* params [[buffer(3)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) + +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ + 
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) + +instantiate_gemm_shapes_helper(float16, half, float16, half); +instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); + +instantiate_gemm_shapes_helper(float32, float, float32, float); diff --git a/candle-metal-kernels/src/gemm/kernels/steel_gemm_addmm.metal b/candle-metal-kernels/src/gemm/kernels/steel_gemm_addmm.metal new file mode 100644 index 00000000..b8e131f0 --- /dev/null +++ b/candle-metal-kernels/src/gemm/kernels/steel_gemm_addmm.metal @@ -0,0 +1,254 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" + +using namespace metal; +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template > +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm( + const device T *A [[buffer(0)]], + const device T *B [[buffer(1)]], + const device T *C [[buffer(2)]], + device T *D [[buffer(3)]], + const constant GEMMAddMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + // Pacifying compiler + (void)lid; + + using gemm_kernel = + GEMMKernel; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Adjust for batch + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + C += params->batch_stride_c * tid.z; + D += params->batch_stride_d * tid.z; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + + A += transpose_a ? c_row : c_row * params->lda; + B += transpose_b ? 
c_col * params->ldb : c_col; + C += c_row * params->ldc + c_col * params->fdc; + D += c_row * params->ldd + c_col; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + const Epilogue epilogue_op(params->alpha, params->beta); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op); + return; + + } else if (tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + return mma_op.store_result_safe( + D, params->ldd, + C, params->ldc, params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op); + + } else if (tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + return mma_op.store_result_safe( + D, params->ldd, + C, params->ldc, params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op); + + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + return mma_op.store_result_safe( + D, params->ldd, + C, params->ldc, params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel initializations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \ + template [[host_name("steel_addmm_" #tname "_" #iname "_" 
#oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \ + [[kernel]] void addmm>( \ + const device itype *A [[buffer(0)]], \ + const device itype *B [[buffer(1)]], \ + const device itype *C [[buffer(2)]], \ + device itype *D [[buffer(3)]], \ + const constant GEMMAddMMParams* params [[buffer(4)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby) + +#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ + instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ + instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ + instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) + +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) + +instantiate_gemm_shapes_helper(float16, half, float16, half); +instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); + +instantiate_gemm_shapes_helper(float32, float, float32, float); \ No newline at end of file diff --git a/candle-metal-kernels/src/gemm/kernels/steel_gemm_splitk.metal b/candle-metal-kernels/src/gemm/kernels/steel_gemm_splitk.metal new file mode 100644 index 00000000..873f5faf --- /dev/null +++ b/candle-metal-kernels/src/gemm/kernels/steel_gemm_splitk.metal @@ -0,0 +1,280 @@ +// Copyright © 2024 Apple Inc. 
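For reference, the addmm kernel above runs the same tiled GEMM and then applies an epilogue when storing D: as the names suggest, TransformAxpby computes D = alpha * (A @ B) + beta * C, with TransformAdd corresponding to the alpha = beta = 1 case. Below is a scalar C++ reference of the axpby form with illustrative shapes; in the kernel the epilogue is applied inside mma_op.store_result / store_result_safe.

```cpp
// Scalar reference for what the addmm kernel stores per element of D
// (axpby epilogue). Plain C++, row-major, illustrative sizes.
#include <cstdio>
#include <vector>

int main() {
  const int M = 2, N = 2, K = 3;
  const float alpha = 2.0f, beta = 0.5f;
  std::vector<float> A = {1, 2, 3, 4, 5, 6};   // M x K
  std::vector<float> B = {1, 0, 0, 1, 1, 1};   // K x N
  std::vector<float> C = {10, 10, 10, 10};     // M x N
  std::vector<float> D(M * N);

  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      D[m * N + n] = alpha * acc + beta * C[m * N + n];  // axpby epilogue
    }

  for (float v : D) printf("%g ", v);
  printf("\n");
  return 0;
}
```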
+ +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" + +using namespace metal; +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk( + const device T *A [[buffer(0)]], + const device T *B [[buffer(1)]], + device U *C [[buffer(2)]], + const constant GEMMSpiltKParams* params [[buffer(3)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + (void)lid; + + using gemm_kernel = GEMMKernel; + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + const int tid_x = tid.x; + const int tid_y = tid.y; + const int tid_z = tid.z; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int k_start = params->split_k_partition_size * tid_z; + + A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda); + B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb); + C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K % BK; + + if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if ((tid_z + 1) == (params->split_k_partitions)) { + int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK; + if(!K_aligned || gemm_k_iter_remaining > 0) + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iter_remaining, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } + + if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + mma_op.store_result(C, params->ldc); + } else { + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// 
GEMM kernel initializations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \ + [[kernel]] void gemm_splitk( \ + const device itype *A [[buffer(0)]], \ + const device itype *B [[buffer(1)]], \ + device otype *C [[buffer(2)]], \ + const constant GEMMSpiltKParams* params [[buffer(3)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) + +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) + +instantiate_gemm_shapes_helper(float16, half, float32, float); +instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float); + +instantiate_gemm_shapes_helper(float32, float, float32, float); + +/////////////////////////////////////////////////////////////////////////////// +// Split k accumulation kernel +/////////////////////////////////////////////////////////////////////////////// + +template > +[[kernel]] void gemm_splitk_accum( + const device AccT *C_split [[buffer(0)]], + device OutT *D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + uint2 gid [[thread_position_in_grid]]) { + + // Ajust D and C + D += gid.x + gid.y * ldd; + C_split += gid.x + gid.y * ldd; + + int offset = 0; + AccT out = 0; + + for(int i = 0; i < k_partitions; i++) { + out += C_split[offset]; + offset += partition_stride; + } + + // Write output + D[0] = Epilogue::apply(out); + +} + +template > +[[kernel]] void gemm_splitk_accum_axpby( + const device AccT *C_split [[buffer(0)]], + device OutT *D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& 
partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + const device OutT *C [[buffer(5)]], + const constant int& ldc [[buffer(6)]], + const constant int& fdc [[buffer(7)]], + const constant float& alpha [[buffer(8)]], + const constant float& beta [[buffer(9)]], + uint2 gid [[thread_position_in_grid]]) { + + // Ajust D and C + C += gid.x * fdc + gid.y * ldc; + D += gid.x + gid.y * ldd; + C_split += gid.x + gid.y * ldd; + + int offset = 0; + AccT out = 0; + + for(int i = 0; i < k_partitions; i++) { + out += C_split[offset]; + offset += partition_stride; + } + + // Write output + Epilogue op(alpha, beta); + D[0] = op.apply(out, *C); + +} + +#define instantiate_accum(oname, otype, aname, atype) \ + template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \ + [[kernel]] void gemm_splitk_accum( \ + const device atype *C_split [[buffer(0)]], \ + device otype *D [[buffer(1)]], \ + const constant int& k_partitions [[buffer(2)]], \ + const constant int& partition_stride [[buffer(3)]], \ + const constant int& ldd [[buffer(4)]], \ + uint2 gid [[thread_position_in_grid]]); \ + template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \ + [[kernel]] void gemm_splitk_accum_axpby( \ + const device atype *C_split [[buffer(0)]], \ + device otype *D [[buffer(1)]], \ + const constant int& k_partitions [[buffer(2)]], \ + const constant int& partition_stride [[buffer(3)]], \ + const constant int& ldd [[buffer(4)]], \ + const device otype *C [[buffer(5)]], \ + const constant int& ldc [[buffer(6)]], \ + const constant int& fdc [[buffer(7)]], \ + const constant float& alpha [[buffer(8)]], \ + const constant float& beta [[buffer(9)]], \ + uint2 gid [[thread_position_in_grid]]); + +instantiate_accum(bfloat16, bfloat16_t, float32, float); +instantiate_accum(float16, half, float32, float); +instantiate_accum(float32, float, float32, float); \ No newline at end of file diff --git a/candle-metal-kernels/src/gemm/loader.h b/candle-metal-kernels/src/gemm/loader.h new file mode 100644 index 00000000..df61bda0 --- /dev/null +++ b/candle-metal-kernels/src/gemm/loader.h @@ -0,0 +1,125 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "utils2.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoader { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/candle-metal-kernels/src/gemm/mma.h b/candle-metal-kernels/src/gemm/mma.h new file mode 100644 index 00000000..d472073e --- /dev/null +++ b/candle-metal-kernels/src/gemm/mma.h @@ -0,0 +1,264 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "gemm/transforms.h" +#include "utils.h" + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = 8 * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Strides of A, B along reduction axis + STEEL_CONST short simd_stride_a = { + transpose_a ? TM_stride : TM_stride * lda_tgp}; + STEEL_CONST short simd_stride_b = { + transpose_b ? TN_stride * ldb_tgp : TN_stride}; + + // Jump between elements + STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; + STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; + + STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; + STEEL_CONST short tile_stride_b = {transpose_b ? 
8 : 8 * ldb_tgp}; + + // Simdgroup matrices + simdgroup_matrix Asimd[TM]; + simdgroup_matrix Bsimd[TN]; + simdgroup_matrix results[TM * TN] = { + simdgroup_matrix(0)}; + + // Offsets within threadgroup + const short tm; + const short tn; + + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + // Determine thread position in simdgroup matrix + short qid = simd_lane_id / 4; + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Determine thread and simdgroup offset + As_offset = + transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); + Bs_offset = + transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of 8 + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += 8) { + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup A as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = + static_cast(As[i * simd_stride_a + 0]); + Asimd[i].thread_elements()[1] = + static_cast(As[i * simd_stride_a + jump_a]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup B as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = + static_cast(Bs[j * simd_stride_b + 0]); + Bsimd[j].thread_elements()[1] = + static_cast(Bs[j * simd_stride_b + jump_b]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Multiply and accumulate into result simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + short j_serp = (i % 2) ? 
(TN - 1 - j) : j; + + simdgroup_multiply_accumulate( + results[i * TN + j_serp], + Asimd[i], + Bsimd[j_serp], + results[i * TN + j_serp]); + } + } + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* C, const int ldc) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue + U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; + + // Write out C + C[offset] = outs[0]; + C[offset + 1] = outs[1]; + } + } + } + + METAL_FUNC void + store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn); + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + C[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + U outs[2] = { + epilogue_op.apply(accum[0], C[offset_c]), + epilogue_op.apply(accum[1], C[offset_c + fdc])}; + + // Write out D + D[offset_d] = outs[0]; + D[offset_d + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < 
dst_tile_dims.x) { + D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + } + } + } + } + } +}; + +} // namespace steel +} // namespace mlx diff --git a/candle-metal-kernels/src/gemm/params.h b/candle-metal-kernels/src/gemm/params.h new file mode 100644 index 00000000..d7e4db04 --- /dev/null +++ b/candle-metal-kernels/src/gemm/params.h @@ -0,0 +1,79 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// GEMM param classes +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +struct GEMMParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldc; + + const int tiles_n; + const int tiles_m; + + const int batch_stride_a; + const int batch_stride_b; + const int batch_stride_c; + + const int swizzle_log; + const int gemm_k_iterations_aligned; +}; + +struct GEMMSpiltKParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldc; + + const int tiles_n; + const int tiles_m; + + const int split_k_partitions; + const int split_k_partition_stride; + const int split_k_partition_size; + + const int gemm_k_iterations_aligned; +}; + +struct GEMMAddMMParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldc; + const int ldd; + + const int tiles_n; + const int tiles_m; + + const int batch_stride_a; + const int batch_stride_b; + const int batch_stride_c; + const int batch_stride_d; + + const int swizzle_log; + const int gemm_k_iterations_aligned; + + const float alpha; + const float beta; + + const int fdc; +}; + +} // namespace steel +} // namespace mlx diff --git a/candle-metal-kernels/src/gemm/steel_gemm.metallib b/candle-metal-kernels/src/gemm/steel_gemm.metallib new file mode 100644 index 00000000..499aee79 Binary files /dev/null and b/candle-metal-kernels/src/gemm/steel_gemm.metallib differ diff --git a/candle-metal-kernels/src/gemm/transforms.h b/candle-metal-kernels/src/gemm/transforms.h new file mode 100644 index 00000000..952ce784 --- /dev/null +++ b/candle-metal-kernels/src/gemm/transforms.h @@ -0,0 +1,63 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "utils.h" + +/////////////////////////////////////////////////////////////////////////////// +// Transforms and Epilogues +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +} // namespace steel +} // namespace mlx diff --git a/candle-metal-kernels/src/gemm/utils.h b/candle-metal-kernels/src/gemm/utils.h new file mode 100644 index 00000000..b828baf9 --- /dev/null +++ b/candle-metal-kernels/src/gemm/utils.h @@ -0,0 +1,276 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include "gemm/bf16.h" +#include "gemm/complex.h" + +/////////////////////////////////////////////////////////////////////////////// +// Type limits utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct Limits { + static const constant U max = metal::numeric_limits::max(); + static const constant U min = metal::numeric_limits::min(); + static const constant U finite_max = metal::numeric_limits::max(); + static const constant U finite_min = metal::numeric_limits::min(); +}; + +#define instantiate_default_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = metal::numeric_limits::max(); \ + static constexpr constant type min = metal::numeric_limits::min(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + metal::numeric_limits::min(); \ + }; + +instantiate_default_limit(uint8_t); +instantiate_default_limit(uint16_t); +instantiate_default_limit(uint32_t); +instantiate_default_limit(uint64_t); +instantiate_default_limit(int8_t); +instantiate_default_limit(int16_t); +instantiate_default_limit(int32_t); +instantiate_default_limit(int64_t); + +#define instantiate_float_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = \ + metal::numeric_limits::infinity(); \ + static constexpr constant type min = \ + -metal::numeric_limits::infinity(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + -metal::numeric_limits::max(); \ + }; + +instantiate_float_limit(half); +instantiate_float_limit(float); +instantiate_float_limit(bfloat16_t); + +template <> +struct Limits { + static constexpr constant bool max = true; + static constexpr constant bool min = false; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Indexing utils 
+/////////////////////////////////////////////////////////////////////////////// + +inline size_t elem_to_loc( + uint elem, + device const int* shape, + device const size_t* strides, + int ndim) { + size_t loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +inline size_t elem_to_loc( + uint elem, + constant const int* shape, + constant const size_t* strides, + int ndim) { + size_t loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +template +inline uint2 elem_to_loc_2_nd( + uint3 elem, + constant const int shape[NDIM], + constant const size_t a_strides[NDIM], + constant const size_t b_strides[NDIM]) { + uint2 loc = { + static_cast( + elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]), + static_cast( + elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])}; + for (int d = NDIM - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * a_strides[d]; + loc.y += l * b_strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +template +inline size_t elem_to_loc_nd( + uint3 elem, + constant const int shape[NDIM], + constant const size_t strides[NDIM]) { + size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2]; + for (int d = NDIM - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) { + return elem * stride; +} + +inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) { + return elem.x * strides[1] + elem.y * strides[0]; +} + +inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) { + return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0]; +} + +// Non templated version to handle arbitrary dims +inline size_t elem_to_loc( + uint3 elem, + constant const int* shape, + constant const size_t* strides, + int ndim) { + size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2]; + for (int d = ndim - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +inline uint2 elem_to_loc_2_nd( + uint3 elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + int ndim) { + uint2 loc = { + static_cast( + elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), + static_cast( + elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])}; + for (int d = ndim - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * a_strides[d]; + loc.y += l * b_strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +template +inline uint elem_to_loc_nd( + uint elem, + device const int* shape, + device const size_t* strides); + +template <> +inline uint elem_to_loc_nd<1>( + uint elem, + device const int* shape, + device const size_t* strides) { + return (elem % shape[0]) * strides[0]; +} + +template <> +inline uint elem_to_loc_nd<2>( + uint elem, + device const int* shape, + device const size_t* strides) { + uint loc = (elem % shape[1]) * strides[1]; + elem /= shape[1]; + loc += (elem % shape[0]) * strides[0]; + return loc; +} + +template <> +inline uint elem_to_loc_nd<3>( + uint elem, + device const int* shape, + device const size_t* strides) { + uint loc = (elem % shape[2]) * strides[2]; + elem /= shape[2]; + loc += (elem % shape[1]) * strides[1]; + elem /= shape[1]; + loc += (elem % shape[0]) * strides[0]; + return loc; +} + 
+template <> +inline uint elem_to_loc_nd<4>( + uint elem, + device const int* shape, + device const size_t* strides) { + uint loc = (elem % shape[3]) * strides[3]; + elem /= shape[3]; + loc += (elem % shape[2]) * strides[2]; + elem /= shape[2]; + loc += (elem % shape[1]) * strides[1]; + elem /= shape[1]; + loc += (elem % shape[0]) * strides[0]; + return loc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Calculation utils +/////////////////////////////////////////////////////////////////////////////// + +/** Compute ceil((float)N/(float)M) */ +inline size_t ceildiv(size_t N, size_t M) { + return (N + M - 1) / M; +} + +// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202 +inline float log1p(float x) { + float xp1 = 1.0f + x; + if (xp1 == Limits::max) { + return Limits::max; + } + if (xp1 == 1.0f) { + return x; + } + + return x * (metal::log(xp1) / (xp1 - 1.0f)); +} + +inline bfloat16_t log1p(bfloat16_t x) { + float xp1 = 1.0f + static_cast(x); + if (xp1 == Limits::max) { + return Limits::max; + } + if (xp1 == 1.0f) { + return x; + } + + return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); +} + +/////////////////////////////////////////////////////////////////////////////// +// SIMD shuffle ops +/////////////////////////////////////////////////////////////////////////////// + +inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) { + return as_type( + metal::simd_shuffle_down(as_type(data), delta)); +} + +inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) { + return as_type( + metal::simd_shuffle_down(as_type(data), delta)); +} + +inline bool simd_shuffle_down(bool data, uint16_t delta) { + return simd_shuffle_down(static_cast(data), delta); +} diff --git a/candle-metal-kernels/src/gemm/utils2.h b/candle-metal-kernels/src/gemm/utils2.h new file mode 100644 index 00000000..c43f3ef8 --- /dev/null +++ b/candle-metal-kernels/src/gemm/utils2.h @@ -0,0 +1,9 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include "gemm/host.h" + +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 2d27d230..a1dc8b02 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,6 +1,7 @@ use metal::{ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, - Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, + Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize, + NSUInteger, }; use std::collections::HashMap; use std::ffi::c_void; @@ -16,6 +17,7 @@ const CONV: &str = include_str!("conv.metal"); const REDUCE: &str = include_str!("reduce.metal"); const RANDOM: &str = include_str!("random.metal"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); +const GEMM: &[u8] = include_bytes!("gemm/steel_gemm.metallib"); const QUANTIZED: &str = include_str!("quantized.metal"); /// Most kernels apply similarly across the tensors @@ -122,6 +124,7 @@ pub enum Source { Cast, Reduce, Mfa, + Gemm, Conv, Random, Quantized, @@ -248,6 +251,7 @@ impl Kernels { Source::Random => RANDOM, Source::Quantized => QUANTIZED, Source::Mfa => panic!("Invalid lib"), + Source::Gemm => panic!("Invalid lib"), } } @@ -271,6 +275,14 @@ impl Kernels { )) })? 
} + Source::Gemm => { + let source_data = GEMM; + device.new_library_with_data(source_data).map_err(|e| { + MetalKernelError::LoadLibraryError(format!( + "Candle metal requires macOS 13.0 or higher, cannot load GEMM: {e}" + )) + })? + } source => { let source_content = self.get_library_source(source); device @@ -1230,6 +1242,34 @@ impl ConstantValues { } } +fn string_to_static_str(s: String) -> &'static str { + Box::leak(s.into_boxed_str()) +} + +use core::ffi::c_int; + +#[repr(C)] +#[derive(Debug)] +struct GEMMParams { + m: c_int, + n: c_int, + k: c_int, + + lda: c_int, + ldb: c_int, + ldc: c_int, + + tiles_n: c_int, + tiles_m: c_int, + + batch_stride_a: c_int, + batch_stride_b: c_int, + batch_stride_c: c_int, + + swizzle_log: c_int, + gemm_k_iterations_aligned: c_int, +} + #[allow(clippy::too_many_arguments)] pub fn call_gemm( device: &Device, @@ -1251,10 +1291,10 @@ pub fn call_gemm( let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - let a_trans = if lhs_m1 == 1 && lhs_m2 == k { - false + let (a_trans, lda) = if lhs_m1 == 1 && lhs_m2 == k { + (false, k as c_int) } else if lhs_m1 == m && lhs_m2 == 1 { - true + (true, m as c_int) } else { return Err(MetalKernelError::MatMulNonContiguous { lhs_stride: lhs_stride.to_vec(), @@ -1262,10 +1302,10 @@ mnk: (m, n, k), })?; }; - let b_trans = if rhs_m1 == 1 && rhs_m2 == n { - false + let (b_trans, ldb) = if rhs_m1 == 1 && rhs_m2 == n { + (false, n as c_int) } else if rhs_m1 == k && rhs_m2 == 1 { - true + (true, k as c_int) } else { return Err(MetalKernelError::MatMulNonContiguous { lhs_stride: lhs_stride.to_vec(), @@ -1273,119 +1313,195 @@ mnk: (m, n, k), })?; }; - let d_trans = false; - let alpha = 1.0f32; - let beta = 0.0f32; - let batched = b > 1; - let fused_activation = false; - let fused_bias = false; - let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { - let m_simd = 8; - let n_simd = 8; - let k_simd = 64; - let m_splits = 1; - let n_splits = 1; - (m_simd, n_simd, k_simd, m_splits, n_splits) - } else { - let m_simd = 40; - let n_simd = 40; - let k_simd = 32; - let m_splits = 1; - let n_splits = 1; - (m_simd, n_simd, k_simd, m_splits, n_splits) - }; - let constants = Some(ConstantValues::new(vec![ - (0, Value::USize(m)), - (1, Value::USize(n)), - (2, Value::USize(k)), - (10, Value::Bool(a_trans)), - (11, Value::Bool(b_trans)), - (13, Value::Bool(d_trans)), - (20, Value::F32(alpha)), - (21, Value::F32(beta)), - (100, Value::Bool(batched)), - (101, Value::Bool(fused_activation)), - // Garbage - (102, Value::Bool(false)), - (103, Value::Bool(false)), - (113, Value::Bool(false)), - (50_000, Value::Bool(false)), - // End garbage - (200, Value::U16(m_simd)), - (201, Value::U16(n_simd)), - (202, Value::U16(k_simd)), - (210, Value::U16(m_splits)), - (211, Value::U16(n_splits)), - (50_001, Value::Bool(fused_bias)), - ])); - let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?; - let m_group = m_simd * m_splits; - let n_group = n_simd * n_splits; - - let a_block_length = m_group * k_simd; - let b_block_length = k_simd * n_group; - - let mut block_elements = a_block_length + b_block_length; - if (m % 8 != 0) && (n % 8 != 0) { - let c_block_length = m_group * n_group; - block_elements = std::cmp::max(c_block_length, block_elements) - } - if fused_bias { - if d_trans { - block_elements = std::cmp::max(block_elements, m_group); - } else { - block_elements =
std::cmp::max(block_elements, n_group); - } - } - let bytes = match name { - "sgemm" => 4, - "hgemm" => 2, + // let d_trans = false; + // let alpha = 1.0f32; + // let beta = 0.0f32; + // let batched = b > 1; + // let fused_activation = false; + // let fused_bias = false; + // let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { + // let m_simd = 8; + // let n_simd = 8; + // let k_simd = 64; + // let m_splits = 1; + // let n_splits = 1; + // (m_simd, n_simd, k_simd, m_splits, n_splits) + // } else { + // let m_simd = 40; + // let n_simd = 40; + // let k_simd = 32; + // let m_splits = 1; + // let n_splits = 1; + // (m_simd, n_simd, k_simd, m_splits, n_splits) + // }; + // let constants = Some(ConstantValues::new(vec![ + // (0, Value::USize(m)), + // (1, Value::USize(n)), + // (2, Value::USize(k)), + // (10, Value::Bool(a_trans)), + // (11, Value::Bool(b_trans)), + // (13, Value::Bool(d_trans)), + // (20, Value::F32(alpha)), + // (21, Value::F32(beta)), + // (100, Value::Bool(batched)), + // (101, Value::Bool(fused_activation)), + // // Garbage + // (102, Value::Bool(false)), + // (103, Value::Bool(false)), + // (113, Value::Bool(false)), + // (50_000, Value::Bool(false)), + // // End garbage + // (200, Value::U16(m_simd)), + // (201, Value::U16(n_simd)), + // (202, Value::U16(k_simd)), + // (210, Value::U16(m_splits)), + // (211, Value::U16(n_splits)), + // (50_001, Value::Bool(fused_bias)), + // ])); + let a_trans_name = if a_trans { "t" } else { "n" }; + let b_trans_name = if b_trans { "t" } else { "n" }; + let (iname, oname) = match name { + "sgemm" => ("float32", "float32"), + "hgemm" => ("float16", "float16"), + "bgemm" => ("bfloat16", "bfloat16"), other => { return Err(MetalKernelError::LoadLibraryError(format!( "{other} is not a valid kernel for gemm" - ))); + ))) } }; - let block_bytes = block_elements * bytes; + let mut bm = 32; + let mut bn = 32; + let mut bk = 16; + let wm = 2; + let wn = 2; + if b * m * n >= 1 << 20 { + if !a_trans && b_trans { + bm = 64; + bn = if oname == "float32" { 64 } else { 32 }; + bk = if oname == "float32" { 16 } else { 32 }; + } else { + bm = 64; + bn = 64; + } + } + let mnaligned = if m % bm == 0 && n % bn == 0 { + "taligned" + } else { + "naligned" + }; + let kaligned = if k % bk == 0 { "taligned" } else { "naligned" }; + // let bytes = match &name[..] 
{ + // "sgemm" => 4, + // "hgemm" => 2, + // other => { + // return Err(MetalKernelError::LoadLibraryError(format!( + // "{other} is not a valid kernel for gemm" + // ))); + // } + // }; + let name = format!("steel_gemm_{a_trans_name}{b_trans_name}_{iname}_{oname}_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}_MN_{mnaligned}_K_{kaligned}"); + let name = string_to_static_str(name); + let pipeline = kernels.load_pipeline(device, Source::Gemm, name)?; + // let m_group = m_simd * m_splits; + // let n_group = n_simd * n_splits; + + // let a_block_length = m_group * k_simd; + // let b_block_length = k_simd * n_group; + + // let mut block_elements = a_block_length + b_block_length; + // if (m % 8 != 0) && (n % 8 != 0) { + // let c_block_length = m_group * n_group; + // block_elements = std::cmp::max(c_block_length, block_elements) + // } + // if fused_bias { + // if d_trans { + // block_elements = std::cmp::max(block_elements, m_group); + // } else { + // block_elements = std::cmp::max(block_elements, n_group); + // } + // } + // let block_bytes = block_elements * bytes; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_threadgroup_memory_length(0, block_bytes.into()); + // encoder.set_threadgroup_memory_length(0, block_bytes.into()); + + let batch_stride_a: i32 = if lhs_stride.len() > 2 { + lhs_stride[lhs_stride.len() - 3] as i32 + } else { + 0 + }; + let batch_stride_b: i32 = if rhs_stride.len() > 2 { + rhs_stride[rhs_stride.len() - 3] as i32 + } else { + 0 + }; + let batch_stride_c = (m * n) as i32; + + let swizzle_log = 0; + let tiles_n = ((n + bn - 1) / bn) as c_int; + let tiles_m = ((m + bm - 1) / bm) as c_int; + + let params = GEMMParams { + m: m as c_int, + n: n as c_int, + k: k as c_int, + lda, + ldb, + ldc: n as c_int, + tiles_m, + tiles_n, + batch_stride_a, + batch_stride_b, + batch_stride_c, + swizzle_log, + gemm_k_iterations_aligned: (k / bk) as c_int, + }; + let params_buffer = device.new_buffer_with_data( + &params as *const GEMMParams as *const c_void, + core::mem::size_of::<GEMMParams>() as u64, + MTLResourceOptions::StorageModeShared, + ); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); encoder.set_buffer(2, Some(output), 0); + encoder.set_buffer(3, Some(&params_buffer), 0); // TODO Tensor D let grid_z = b; - if batched { - let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize; - let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize; - let byte_stride_c = m * n * bytes as usize; - // TODO byte_stride_d - let byte_stride_d = 0; + // if batched { + // let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize; + // let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize; + // let byte_stride_c = m * n * bytes as usize; + // // TODO byte_stride_d + // let byte_stride_d = 0; - let buffer: Vec = vec![ - byte_stride_a as _, - byte_stride_b as _, - byte_stride_c as _, - byte_stride_d as _, - ]; - encoder.set_bytes( - 10, - (buffer.len() * core::mem::size_of::()) as NSUInteger, - buffer.as_ptr() as *const NSUInteger as *const c_void, - ); - } + // let buffer: Vec = vec![ + // byte_stride_a as _, + // byte_stride_b as _, + // byte_stride_c as _, + // byte_stride_d as _, + // ]; + // // encoder.set_bytes( + // // 10, + // // (buffer.len() * core::mem::size_of::()) as NSUInteger, + // // buffer.as_ptr() as *const NSUInteger as *const c_void, + // // ); + // } + let tile = 1 <<
swizzle_log; + let tm = (tiles_m + tile - 1) / tile; + let tn = tiles_n * tile; let grid_size = MTLSize { - width: divide(n, n_group.into()), - height: divide(m, m_group.into()), + width: tn as u64, + height: tm as u64, depth: grid_z as NSUInteger, }; let group_size = MTLSize { - width: 32 * (m_splits as u64) * (n_splits as u64), - height: 1, - depth: 1, + width: 32, + height: wn, + depth: wm, }; encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
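+ // Note on the dispatch geometry above (illustrative worked example, not part of
+ // the kernel interface): each threadgroup computes one bm x bn tile of the output
+ // and each batch entry gets its own grid layer. For example, with b = 1,
+ // m = n = 100, k = 64, the sgemm path and both operands contiguous ("nn"), the
+ // heuristic keeps the 32x32x16 tile because b * m * n < 1 << 20, so the pipeline
+ // name resolves to
+ //   steel_gemm_nn_float32_float32_bm32_bn32_bk16_wm2_wn2_MN_naligned_K_taligned
+ // (100 % 32 != 0 makes MN unaligned, 64 % 16 == 0 keeps K aligned), and with
+ // swizzle_log = 0:
+ //   tiles_n = ceil(100 / 32) = 4, tiles_m = ceil(100 / 32) = 4
+ //   grid_size  = (4, 4, 1)
+ //   group_size = (32, wn, wm) = (32, 2, 2), i.e. wm * wn = 4 simdgroups of 32 threads.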
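+ // For reference, a plain-Rust sketch of what the split-k accumulation kernels
+ // (gemm_splitk_accum / gemm_splitk_accum_axpby instantiated earlier in this patch)
+ // compute per output element. The helper below is hypothetical and kept in a
+ // comment so it is not compiled into this crate:
+ //
+ // fn splitk_accum(c_split: &[f32], k_partitions: usize, partition_stride: usize,
+ //                 ldd: usize, m: usize, n: usize) -> Vec<f32> {
+ //     let mut d = vec![0.0f32; m * ldd];
+ //     for y in 0..m {
+ //         for x in 0..n {
+ //             let base = x + y * ldd;
+ //             // The partial results for one element are partition_stride apart.
+ //             d[base] = (0..k_partitions).map(|p| c_split[base + p * partition_stride]).sum();
+ //         }
+ //     }
+ //     d
+ // }
+ //
+ // The _axpby variant additionally reads C at x * fdc + y * ldc and stores
+ // alpha * acc + beta * c through the TransformAxpby epilogue.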