mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
1 Commits
0.8.3
...
bf16_metal
Author | SHA1 | Date | |
---|---|---|---|
e2bf0adc2a |
2
candle-metal-kernels/compile.sh
Normal file
2
candle-metal-kernels/compile.sh
Normal file
@ -0,0 +1,2 @@
|
||||
# Build the steel GEMM Metal library in two steps: compile the kernel source
# (which leaves steel_gemm.air in the current directory), then package that
# file into src/gemm/steel_gemm.metallib.
# Stop at the first failing command so a failed compile is never followed by
# packaging a stale or missing steel_gemm.air.
set -e
xcrun metal -c src/gemm/kernels/steel_gemm.metal -I src/
xcrun metallib steel_gemm.air -o src/gemm/steel_gemm.metallib
|
317
candle-metal-kernels/src/gemm/bf16.h
Normal file
317
candle-metal-kernels/src/gemm/bf16.h
Normal file
@ -0,0 +1,317 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
|
||||
// The toolchain provides a native bfloat type: expose it under the common name.
using bfloat16_t = bfloat;
|
||||
|
||||
#else
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Helpers
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Converts a float to a 16-bit bfloat16 bit pattern.
// NaN inputs yield the fixed pattern 0x7FC0; every other input is rounded to
// nearest (ties to even) and reduced to the upper 16 bits of the float.
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
  // Reinterpret the float once and reuse the raw bits below.
  uint32_t raw = as_type<uint32_t>(x);

  // With the sign bit cleared, anything strictly above the infinity pattern
  // is a NaN.
  const auto magnitude = raw & ~_fp_encoding_traits<float>::sign_mask;
  if (magnitude > _fp_encoding_traits<float>::inf_mask) {
    return uint16_t(0x7FC0);
  }

  // Rounding bias: 0x7FFF plus the lowest bit that survives the truncation,
  // so exact ties round towards an even result.
  const uint32_t kept_lsb = (raw >> 16) & 1;
  raw += kept_lsb + uint32_t(0x7FFF);

  // The upper half of the biased value is the bfloat16 pattern.
  return raw >> 16;
}
|
||||
|
||||
// Expands a bfloat16 bit pattern to a float: the 16 bits become the upper half
// of the 32-bit float pattern and the lower half is zero.
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
  const uint32_t widened = static_cast<uint32_t>(x) << 16;
  return as_type<float>(widened);
}
|
||||
|
||||
struct _MLX_BFloat16;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_to_bfloat =
|
||||
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_from_bfloat =
|
||||
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat struct
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct _MLX_BFloat16 {
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Constructors
|
||||
uint16_t bits_;
|
||||
_MLX_BFloat16() thread = default;
|
||||
_MLX_BFloat16() threadgroup = default;
|
||||
_MLX_BFloat16() device = default;
|
||||
_MLX_BFloat16() constant = default;
|
||||
|
||||
struct bits_to_bfloat_struct {};
|
||||
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
|
||||
return bits_to_bfloat_struct();
|
||||
}
|
||||
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
|
||||
: bits_(bits) {}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Conversions to bfloat
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) device
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
|
||||
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Conversions from bfloat
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const thread {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const threadgroup {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const device {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||
constexpr METAL_FUNC operator T() const constant {
|
||||
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat operators
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Unary ops
|
||||
// Negation: negate the value as a float, then convert back to bfloat16.
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
  return _MLX_BFloat16(-static_cast<float>(x));
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Binary operators
|
||||
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
|
||||
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
|
||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||
}
|
||||
|
||||
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
|
||||
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
|
||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||
} \
|
||||
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
|
||||
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Arithmetic Operators
|
||||
#define bfloat_binop(_op_, _operator_) \
|
||||
bfloat_binop_base( \
|
||||
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, float, float, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, float, half, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
|
||||
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
|
||||
|
||||
bfloat_binop(+, operator+);
|
||||
bfloat_binop(-, operator-);
|
||||
bfloat_binop(*, operator*);
|
||||
bfloat_binop(/, operator/);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Comparison ops
|
||||
#define bfloat_compop(__op__, __operator__) \
|
||||
bfloat_binop_base( \
|
||||
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
||||
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
||||
|
||||
bfloat_compop(>, operator>);
|
||||
bfloat_compop(<, operator<);
|
||||
bfloat_compop(>=, operator>=);
|
||||
bfloat_compop(<=, operator<=);
|
||||
bfloat_compop(==, operator==);
|
||||
bfloat_compop(!=, operator!=);
|
||||
|
||||
#undef bfloat_compop
|
||||
#undef bfloat_binop_base
|
||||
#undef bfloat_binop_helper
|
||||
#undef bfloat_binop
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Inplace Operators
|
||||
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
|
||||
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
||||
addr_space _MLX_BFloat16& lhs, itype rhs) { \
|
||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||
return lhs; \
|
||||
} \
|
||||
constexpr METAL_FUNC addr_space itype& __operator__( \
|
||||
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
|
||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||
return lhs; \
|
||||
}
|
||||
|
||||
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
|
||||
|
||||
#define bfloat_inplace_op(itype) \
|
||||
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
|
||||
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
|
||||
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
|
||||
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
|
||||
|
||||
bfloat_inplace_op(float);
|
||||
bfloat_inplace_op(half);
|
||||
bfloat_inplace_op(int16_t);
|
||||
bfloat_inplace_op(int32_t);
|
||||
bfloat_inplace_op(int64_t);
|
||||
bfloat_inplace_op(uint16_t);
|
||||
bfloat_inplace_op(uint32_t);
|
||||
bfloat_inplace_op(uint64_t);
|
||||
|
||||
#undef bfloat_inplace_op_helper
|
||||
#undef bfloat_inplace_op_addr_space_helper
|
||||
#undef bfloat_inplace_op
|
||||
|
||||
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
|
||||
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
||||
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
|
||||
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||
return lhs; \
|
||||
}
|
||||
|
||||
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, device); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, thread); \
|
||||
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
|
||||
|
||||
bfloat_inplace_op_addr_space_helper(+, operator+=);
|
||||
bfloat_inplace_op_addr_space_helper(-, operator-=);
|
||||
bfloat_inplace_op_addr_space_helper(*, operator*=);
|
||||
bfloat_inplace_op_addr_space_helper(/, operator/=);
|
||||
|
||||
#undef bfloat_inplace_op_helper
|
||||
#undef bfloat_inplace_op_addr_space_helper
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat typedef
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
typedef struct _MLX_BFloat16 bfloat16_t;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat numeric limits
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#pragma METAL internals : enable
|
||||
|
||||
namespace metal {
|
||||
|
||||
template <>
|
||||
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
|
||||
static constexpr constant int digits = 8;
|
||||
static constexpr constant int digits10 = 2;
|
||||
static constexpr constant int max_digits10 = 4;
|
||||
static constexpr constant int radix = 2;
|
||||
static constexpr constant int min_exponent = -125;
|
||||
static constexpr constant int min_exponent10 = -37;
|
||||
static constexpr constant int max_exponent = 128;
|
||||
static constexpr constant int max_exponent10 = 38;
|
||||
|
||||
static constexpr bfloat16_t min() {
|
||||
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t lowest() {
|
||||
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t max() {
|
||||
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t epsilon() {
|
||||
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t round_error() {
|
||||
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t infinity() {
|
||||
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t quiet_NaN() {
|
||||
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t signaling_NaN() {
|
||||
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
static constexpr bfloat16_t denorm_min() {
|
||||
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
|
||||
}
|
||||
};
|
||||
|
||||
METAL_FUNC bool isnan(_MLX_BFloat16 x) {
|
||||
return x != x;
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
||||
#pragma METAL internals : disable
|
||||
|
||||
#endif // defined(__HAVE_BFLOAT__)
|
||||
|
||||
#include "gemm/bf16_math.h"
|
394
candle-metal-kernels/src/gemm/bf16_math.h
Normal file
394
candle-metal-kernels/src/gemm/bf16_math.h
Normal file
@ -0,0 +1,394 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "gemm/bf16.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Metal math for bfloat16
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
|
||||
Following the Metal Shading Language Specification (Metal 3.1)
|
||||
|
||||
"bfloat is an extended floating point type that only allows implicit conversion
|
||||
to a type of greater floating point rank. While bfloat can be implicitly
|
||||
converted to float, it cannot be implicitly converted to half, and neither
|
||||
float nor half can be implicitly converted to bfloat."
|
||||
|
||||
Further, as far as I can tell, the stdlib math/simd functions are not defined
|
||||
for bfloat and calling with an argument of type bfloat will result in that
|
||||
argument getting implicitly converted to float which then returns an output
|
||||
that is (likely) a float which cannot be implicitly converted into a bfloat
|
||||
|
||||
This leads to situations where
|
||||
bfloat a = 5.0bf;
|
||||
bfloat b = metal::abs(a); // this will throw an error since abs returns float
|
||||
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
|
||||
|
||||
For the moment, I will be adding overloaded instantiations of the math
|
||||
functions to accordingly automatically handle the casting
|
||||
|
||||
*/
|
||||
|
||||
// Generators for the common wrapper shapes. Each wrapper casts its arguments
// from itype to ctype, calls the named builtin with the mfast flag appended,
// and casts the result to otype.
#define bf16_math_unary(fn, builtin, itype, otype, ctype, mfast)      \
  METAL_FUNC otype fn(itype x) {                                      \
    return static_cast<otype>(builtin(static_cast<ctype>(x), mfast)); \
  }

#define bf16_math_binary(fn, builtin, itype, otype, ctype, mfast)      \
  METAL_FUNC otype fn(itype x, itype y) {                              \
    return static_cast<otype>(                                         \
        builtin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  }

#define bf16_math_ternary(fn, builtin, itype, otype, ctype, mfast) \
  METAL_FUNC otype fn(itype x, itype y, itype z) {                 \
    return static_cast<otype>(builtin(                             \
        static_cast<ctype>(x),                                     \
        static_cast<ctype>(y),                                     \
        static_cast<ctype>(z),                                     \
        mfast));                                                   \
  }

// Emits the math overloads for `itype`: arguments are computed in `ctype` and
// results returned as `otype`. `mfast` is forwarded to the builtins that take
// it. fdim, fma, frexp, ldexp and nextafter do not fit the generators above and
// are written out in full.
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)               \
                                                                               \
  bf16_math_unary(abs, __metal_fabs, itype, otype, ctype, mfast)               \
  bf16_math_unary(acos, __metal_acos, itype, otype, ctype, mfast)              \
  bf16_math_unary(acosh, __metal_acosh, itype, otype, ctype, mfast)            \
  bf16_math_unary(asin, __metal_asin, itype, otype, ctype, mfast)              \
  bf16_math_unary(asinh, __metal_asinh, itype, otype, ctype, mfast)            \
  bf16_math_unary(atan, __metal_atan, itype, otype, ctype, mfast)              \
  bf16_math_binary(atan2, __metal_atan2, itype, otype, ctype, mfast)           \
  bf16_math_unary(atanh, __metal_atanh, itype, otype, ctype, mfast)            \
  bf16_math_unary(ceil, __metal_ceil, itype, otype, ctype, mfast)              \
  bf16_math_unary(cos, __metal_cos, itype, otype, ctype, mfast)                \
  bf16_math_unary(cosh, __metal_cosh, itype, otype, ctype, mfast)              \
  bf16_math_unary(cospi, __metal_cospi, itype, otype, ctype, mfast)            \
  bf16_math_binary(divide, __metal_divide, itype, otype, ctype, mfast)         \
  bf16_math_unary(exp, __metal_exp, itype, otype, ctype, mfast)                \
  bf16_math_unary(exp10, __metal_exp10, itype, otype, ctype, mfast)            \
  bf16_math_unary(exp2, __metal_exp2, itype, otype, ctype, mfast)              \
  bf16_math_unary(fabs, __metal_fabs, itype, otype, ctype, mfast)              \
  METAL_FUNC otype fdim(itype x, itype y) {                                    \
    ctype t = static_cast<ctype>(x - y);                                       \
    return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y));    \
  }                                                                            \
  bf16_math_unary(floor, __metal_floor, itype, otype, ctype, mfast)            \
  METAL_FUNC otype fma(itype x, itype y, itype z) {                            \
    return static_cast<otype>(__metal_fma(                                     \
        static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
  }                                                                            \
  bf16_math_binary(fmax, __metal_fmax, itype, otype, ctype, mfast)             \
  bf16_math_ternary(fmax3, __metal_fmax3, itype, otype, ctype, mfast)          \
  bf16_math_ternary(fmedian3, __metal_fmedian3, itype, otype, ctype, mfast)    \
  bf16_math_binary(fmin, __metal_fmin, itype, otype, ctype, mfast)             \
  bf16_math_ternary(fmin3, __metal_fmin3, itype, otype, ctype, mfast)          \
  bf16_math_binary(fmod, __metal_fmod, itype, otype, ctype, mfast)             \
  bf16_math_unary(fract, __metal_fract, itype, otype, ctype, mfast)            \
  METAL_FUNC otype frexp(itype x, thread int& exp) {                           \
    return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp));     \
  }                                                                            \
  METAL_FUNC otype ldexp(itype x, int k) {                                     \
    return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
  }                                                                            \
  bf16_math_unary(log, __metal_log, itype, otype, ctype, mfast)                \
  bf16_math_unary(log10, __metal_log10, itype, otype, ctype, mfast)            \
  bf16_math_unary(log2, __metal_log2, itype, otype, ctype, mfast)              \
  bf16_math_binary(max, __metal_fmax, itype, otype, ctype, mfast)              \
  bf16_math_ternary(max3, __metal_fmax3, itype, otype, ctype, mfast)           \
  bf16_math_ternary(median3, __metal_fmedian3, itype, otype, ctype, mfast)     \
  bf16_math_binary(min, __metal_fmin, itype, otype, ctype, mfast)              \
  bf16_math_ternary(min3, __metal_fmin3, itype, otype, ctype, mfast)           \
  METAL_FUNC otype nextafter(itype x, itype y) {                               \
    return static_cast<otype>(                                                 \
        __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y)));      \
  }                                                                            \
  bf16_math_binary(pow, __metal_pow, itype, otype, ctype, mfast)               \
  bf16_math_binary(powr, __metal_powr, itype, otype, ctype, mfast)             \
  bf16_math_unary(rint, __metal_rint, itype, otype, ctype, mfast)              \
  bf16_math_unary(round, __metal_round, itype, otype, ctype, mfast)            \
  bf16_math_unary(rsqrt, __metal_rsqrt, itype, otype, ctype, mfast)            \
  bf16_math_unary(sin, __metal_sin, itype, otype, ctype, mfast)                \
  bf16_math_unary(sinh, __metal_sinh, itype, otype, ctype, mfast)              \
  bf16_math_unary(sinpi, __metal_sinpi, itype, otype, ctype, mfast)            \
  bf16_math_unary(sqrt, __metal_sqrt, itype, otype, ctype, mfast)              \
  bf16_math_unary(tan, __metal_tan, itype, otype, ctype, mfast)                \
  bf16_math_unary(tanh, __metal_tanh, itype, otype, ctype, mfast)              \
  bf16_math_unary(tanpi, __metal_tanpi, itype, otype, ctype, mfast)            \
  bf16_math_unary(trunc, __metal_trunc, itype, otype, ctype, mfast)
|
||||
|
||||
namespace metal {
|
||||
|
||||
instantiate_metal_math_funcs(
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
float,
|
||||
__METAL_MAYBE_FAST_MATH__);
|
||||
|
||||
namespace fast {
|
||||
|
||||
instantiate_metal_math_funcs(
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
float,
|
||||
__METAL_FAST_MATH__);
|
||||
|
||||
} // namespace fast
|
||||
|
||||
namespace precise {
|
||||
|
||||
instantiate_metal_math_funcs(
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
float,
|
||||
__METAL_PRECISE_MATH__);
|
||||
|
||||
} // namespace precise
|
||||
|
||||
} // namespace metal
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Metal simd for bfloat16
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Emits the simd broadcast/shuffle overloads for `itype`. Each overload maps
// its data argument(s) with `pack`, calls the matching __metal_simd_* builtin
// and maps the result back with `unpack`. The `ctype` parameter is accepted
// but does not appear in the generated code.
#define instantiate_metal_simd_comm_funcs(itype, otype, ctype, pack, unpack)   \
                                                                               \
  METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) {      \
    return unpack(__metal_simd_broadcast(pack(data), broadcast_lane_id));      \
  }                                                                            \
                                                                               \
  METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) {             \
    return unpack(__metal_simd_shuffle(pack(data), simd_lane_id));             \
  }                                                                            \
                                                                               \
  METAL_FUNC otype simd_shuffle_and_fill_down(                                 \
      itype data, itype filling_data, ushort delta, ushort modulo) {           \
    return unpack(__metal_simd_shuffle_and_fill_down(                          \
        pack(data), pack(filling_data), delta, modulo));                       \
  }                                                                            \
                                                                               \
  METAL_FUNC otype simd_shuffle_and_fill_down(                                 \
      itype data, itype filling_data, ushort delta) {                          \
    return unpack(__metal_simd_shuffle_and_fill_down(                          \
        pack(data),                                                            \
        pack(filling_data),                                                    \
        delta,                                                                 \
        __metal_get_simdgroup_size(ushort())));                                \
  }                                                                            \
                                                                               \
  METAL_FUNC otype simd_shuffle_and_fill_up(                                   \
      itype data, itype filling_data, ushort delta, ushort modulo) {           \
    return unpack(__metal_simd_shuffle_and_fill_up(                            \
        pack(data), pack(filling_data), delta, modulo));                       \
  }                                                                            \
                                                                               \
  METAL_FUNC otype simd_shuffle_and_fill_up(                                   \
      itype data, itype filling_data, ushort delta) {                          \
    return unpack(__metal_simd_shuffle_and_fill_up(                            \
        pack(data),                                                            \
        pack(filling_data),                                                    \
        delta,                                                                 \
        __metal_get_simdgroup_size(ushort())));                                \
  }                                                                            \
                                                                               \
  METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) {               \
    return unpack(__metal_simd_shuffle_down(pack(data), delta));               \
  }                                                                            \
                                                                               \
  METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) {        \
    return unpack(__metal_simd_shuffle_rotate_down(pack(data), delta));        \
  }                                                                            \
                                                                               \
  METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) {          \
    return unpack(__metal_simd_shuffle_rotate_up(pack(data), delta));          \
  }                                                                            \
                                                                               \
  METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) {                 \
    return unpack(__metal_simd_shuffle_up(pack(data), delta));                 \
  }                                                                            \
                                                                               \
  METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) {                 \
    return unpack(__metal_simd_shuffle_xor(pack(data), mask));                 \
  }
|
||||
|
||||
// Generator for one simd reduction wrapper: cast the argument to ctype, call
// the builtin, cast the result to otype.
#define bf16_simd_reduce(fn, builtin, itype, otype, ctype)     \
  METAL_FUNC otype fn(itype data) {                            \
    return static_cast<otype>(builtin(static_cast<ctype>(data))); \
  }

// Emits the simd reduction and prefix-scan overloads for `itype`, each
// forwarding to the __metal_simd_* builtin of the same name.
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)            \
                                                                               \
  bf16_simd_reduce(simd_max, __metal_simd_max, itype, otype, ctype)            \
  bf16_simd_reduce(simd_min, __metal_simd_min, itype, otype, ctype)            \
  bf16_simd_reduce(                                                            \
      simd_prefix_exclusive_product,                                           \
      __metal_simd_prefix_exclusive_product,                                   \
      itype,                                                                   \
      otype,                                                                   \
      ctype)                                                                   \
  bf16_simd_reduce(                                                            \
      simd_prefix_exclusive_sum,                                               \
      __metal_simd_prefix_exclusive_sum,                                       \
      itype,                                                                   \
      otype,                                                                   \
      ctype)                                                                   \
  bf16_simd_reduce(                                                            \
      simd_prefix_inclusive_product,                                           \
      __metal_simd_prefix_inclusive_product,                                   \
      itype,                                                                   \
      otype,                                                                   \
      ctype)                                                                   \
  bf16_simd_reduce(                                                            \
      simd_prefix_inclusive_sum,                                               \
      __metal_simd_prefix_inclusive_sum,                                       \
      itype,                                                                   \
      otype,                                                                   \
      ctype)                                                                   \
  bf16_simd_reduce(simd_product, __metal_simd_product, itype, otype, ctype)    \
  bf16_simd_reduce(simd_sum, __metal_simd_sum, itype, otype, ctype)            \
  bf16_simd_reduce(simd_xor, __metal_simd_xor, itype, otype, ctype)
|
||||
|
||||
// Bit-level conversions between bfloat16_t and uint16_t, used to pass bfloat
// values through the simd broadcast/shuffle wrappers as their raw 16-bit
// patterns.
#if defined(__HAVE_BFLOAT__)

// Native bfloat: reinterpret the bits.
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

#else

// Software bfloat: read or construct from the stored pattern. The argument is
// parenthesised so that an expression argument such as `a + b` selects
// `(a + b).bits_` rather than `a + b.bits_`.
#define bfloat16_to_uint16(x) ((x).bits_)
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())

#endif
|
||||
|
||||
namespace metal {
|
||||
|
||||
instantiate_metal_simd_comm_funcs(
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
uint16_t,
|
||||
bfloat16_to_uint16,
|
||||
uint16_to_bfloat16);
|
||||
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
|
||||
|
||||
} // namespace metal
|
131
candle-metal-kernels/src/gemm/complex.h
Normal file
131
candle-metal-kernels/src/gemm/complex.h
Normal file
@ -0,0 +1,131 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct complex64_t;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_to_complex64 =
|
||||
!is_same_v<T, complex64_t> && is_convertible_v<T, float>;
|
||||
|
||||
template <typename T>
|
||||
static constexpr constant bool can_convert_from_complex64 =
|
||||
!is_same_v<T, complex64_t> &&
|
||||
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
|
||||
|
||||
struct complex64_t {
|
||||
float real;
|
||||
float imag;
|
||||
|
||||
// Constructors
|
||||
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
|
||||
|
||||
// Conversions to complex64_t
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) thread : real(x), imag(0) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) device : real(x), imag(0) {}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||
constexpr complex64_t(T x) constant : real(x), imag(0) {}
|
||||
|
||||
// Conversions from complex64_t
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
constexpr operator T() const thread {
|
||||
return static_cast<T>(real);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
constexpr operator T() const threadgroup {
|
||||
return static_cast<T>(real);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
constexpr operator T() const device {
|
||||
return static_cast<T>(real);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||
constexpr operator T() const constant {
|
||||
return static_cast<T>(real);
|
||||
}
|
||||
};
|
||||
|
||||
// Negation: both components change sign.
constexpr complex64_t operator-(complex64_t x) {
  return complex64_t(-x.real, -x.imag);
}
|
||||
|
||||
// Lexicographic ordering: the real parts decide, and the imaginary parts break
// a tie between equal real parts.
constexpr bool operator>=(complex64_t a, complex64_t b) {
  if (a.real > b.real) {
    return true;
  }
  return a.real == b.real && a.imag >= b.imag;
}
|
||||
|
||||
// Strict lexicographic ordering: real parts first, then imaginary parts when
// the real parts are equal.
constexpr bool operator>(complex64_t a, complex64_t b) {
  if (a.real > b.real) {
    return true;
  }
  return a.real == b.real && a.imag > b.imag;
}
|
||||
|
||||
// a <= b is evaluated as b >= a.
constexpr bool operator<=(complex64_t a, complex64_t b) {
  return b >= a;
}
|
||||
|
||||
// a < b is evaluated as b > a.
constexpr bool operator<(complex64_t a, complex64_t b) {
  return b > a;
}
|
||||
|
||||
// Equal only when both the real and the imaginary parts are equal.
constexpr bool operator==(complex64_t a, complex64_t b) {
  if (a.real != b.real) {
    return false;
  }
  return a.imag == b.imag;
}
|
||||
|
||||
// Component-wise sum.
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
  return complex64_t(a.real + b.real, a.imag + b.imag);
}
|
||||
|
||||
// Component-wise difference.
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
  return complex64_t(a.real - b.real, a.imag - b.imag);
}
|
||||
|
||||
// Complex product: (ar + i*ai)(br + i*bi) = (ar*br - ai*bi) + i(ar*bi + ai*br).
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
  return complex64_t(
      a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real);
}
|
||||
|
||||
// Complex quotient: multiply a by the conjugate of b and divide both
// components by |b|^2.
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
  float norm_sq = b.real * b.real + b.imag * b.imag;
  float num_re = a.real * b.real + a.imag * b.imag;
  float num_im = a.imag * b.real - a.real * b.imag;
  return complex64_t(num_re / norm_sq, num_im / norm_sq);
}
|
||||
|
||||
// Component-wise remainder. Each component is reduced independently using a
// quotient truncated through int64_t; a non-zero remainder whose sign differs
// from the divisor's component is then shifted by that divisor component, so
// the result takes the sign of the divisor.
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
  auto rem_real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
  auto rem_imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));

  const bool real_sign_differs = (rem_real < 0) != (b.real < 0);
  if (rem_real != 0 && real_sign_differs) {
    rem_real += b.real;
  }

  const bool imag_sign_differs = (rem_imag < 0) != (b.imag < 0);
  if (rem_imag != 0 && imag_sign_differs) {
    rem_imag += b.imag;
  }

  return {rem_real, rem_imag};
}
|
292
candle-metal-kernels/src/gemm/gemm.h
Normal file
292
candle-metal-kernels/src/gemm/gemm.h
Normal file
@ -0,0 +1,292 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "gemm/loader.h"
|
||||
#include "gemm/mma.h"
|
||||
#include "gemm/transforms.h"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel class
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
// Empty tag type. Its three boolean template arguments carry the M / N / K
// alignment flags so that GEMMKernel::gemm_loop can deduce them from an
// argument instead of requiring explicit template arguments.
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct GEMMKernel {
|
||||
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
|
||||
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
|
||||
STEEL_CONST short tgp_mem_size_a =
|
||||
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
||||
STEEL_CONST short tgp_mem_size_b =
|
||||
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
||||
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
||||
|
||||
STEEL_CONST short tgp_size = WM * WN * 32;
|
||||
|
||||
using loader_a_t = BlockLoader<
|
||||
T,
|
||||
transpose_a ? BK : BM,
|
||||
transpose_a ? BM : BK,
|
||||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
||||
!transpose_a,
|
||||
tgp_size>;
|
||||
using loader_b_t = BlockLoader<
|
||||
T,
|
||||
transpose_b ? BN : BK,
|
||||
transpose_b ? BK : BN,
|
||||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
||||
transpose_b,
|
||||
tgp_size>;
|
||||
using mma_t = BlockMMA<
|
||||
T,
|
||||
U,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
||||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
/* Main kernel function */
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned_>
|
||||
static METAL_FUNC void gemm_loop(
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
const int gemm_k_iterations,
|
||||
thread loader_a_t& loader_a,
|
||||
thread loader_b_t& loader_b,
|
||||
thread mma_t& mma_op,
|
||||
thread const short& tgp_bm,
|
||||
thread const short& tgp_bn,
|
||||
thread const short& lbk,
|
||||
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
|
||||
// Appease the compiler
|
||||
(void)l;
|
||||
|
||||
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||
|
||||
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
if (M_aligned) {
|
||||
loader_a.load_unsafe();
|
||||
} else {
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
}
|
||||
|
||||
if (N_aligned) {
|
||||
loader_b.load_unsafe();
|
||||
} else {
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
if (!K_aligned_) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
short2 tile_dims_A_last =
|
||||
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
|
||||
short2 tile_dims_B_last =
|
||||
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A_last);
|
||||
loader_b.load_safe(tile_dims_B_last);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
}
|
||||
|
||||
/* Main kernel function */
|
||||
static METAL_FUNC void run(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device U* C [[buffer(2)]],
|
||||
const constant GEMMParams* params [[buffer(3)]],
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
C += c_row * params->ldc + c_col;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Loop tail
|
||||
if (!K_aligned) {
|
||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(C, params->ldc);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
if (tgp_bm == BM && tgp_bn == BN) {
|
||||
gemm_loop<true, true, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result(C, params->ldc);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_loop<false, true, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_loop<true, false, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else {
|
||||
gemm_loop<false, false, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
5
candle-metal-kernels/src/gemm/host.h
Normal file
5
candle-metal-kernels/src/gemm/host.h
Normal file
@ -0,0 +1,5 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "params.h"
|
89
candle-metal-kernels/src/gemm/kernels/steel_gemm.metal
Normal file
89
candle-metal-kernels/src/gemm/kernels/steel_gemm.metal
Normal file
@ -0,0 +1,89 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "gemm/bf16.h"
|
||||
#include "gemm/gemm.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
device T *C [[buffer(2)]],
|
||||
const constant GEMMParams* params [[buffer(3)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Adjust for batch
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += params->batch_stride_c * tid.z;
|
||||
|
||||
gemm_kernel::run(
|
||||
A, B, C,
|
||||
params,
|
||||
As, Bs,
|
||||
simd_lane_id, simd_group_id, tid, lid
|
||||
);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
|
||||
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
device itype *C [[buffer(2)]], \
|
||||
const constant GEMMParams* params [[buffer(3)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
||||
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
254
candle-metal-kernels/src/gemm/kernels/steel_gemm_addmm.metal
Normal file
254
candle-metal-kernels/src/gemm/kernels/steel_gemm_addmm.metal
Normal file
@ -0,0 +1,254 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformAdd<T, AccumType>>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
const device T *C [[buffer(2)]],
|
||||
device T *D [[buffer(3)]],
|
||||
const constant GEMMAddMMParams* params [[buffer(4)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
using gemm_kernel =
|
||||
GEMMKernel<T, T, BM, BN, BK, WM, WN,
|
||||
transpose_a, transpose_b,
|
||||
MN_aligned, K_aligned,
|
||||
AccumType, Epilogue>;
|
||||
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Adjust for batch
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += params->batch_stride_c * tid.z;
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
C += c_row * params->ldc + c_col * params->fdc;
|
||||
D += c_row * params->ldd + c_col;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
const Epilogue epilogue_op(params->alpha, params->beta);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Loop tail
|
||||
if (!K_aligned) {
|
||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
if (tgp_bm == BM && tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, true, K_aligned>{});
|
||||
|
||||
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, true, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, params->ldc, params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, false, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, params->ldc, params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
} else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, params->ldc, params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
|
||||
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
|
||||
[[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
const device itype *C [[buffer(2)]], \
|
||||
device itype *D [[buffer(3)]], \
|
||||
const constant GEMMAddMMParams* params [[buffer(4)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
|
||||
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
||||
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
280
candle-metal-kernels/src/gemm/kernels/steel_gemm_splitk.metal
Normal file
280
candle-metal-kernels/src/gemm/kernels/steel_gemm_splitk.metal
Normal file
@ -0,0 +1,280 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
device U *C [[buffer(2)]],
|
||||
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
|
||||
(void)lid;
|
||||
|
||||
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
const int tid_x = tid.x;
|
||||
const int tid_y = tid.y;
|
||||
const int tid_z = tid.z;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const int k_start = params->split_k_partition_size * tid_z;
|
||||
|
||||
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
|
||||
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
|
||||
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K % BK;
|
||||
|
||||
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, true, true>{});
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, true, true>{});
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, false, true>{});
|
||||
} else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, true>{});
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if ((tid_z + 1) == (params->split_k_partitions)) {
|
||||
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
|
||||
if(!K_aligned || gemm_k_iter_remaining > 0)
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iter_remaining,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, K_aligned>{});
|
||||
}
|
||||
|
||||
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||
mma_op.store_result(C, params->ldc);
|
||||
} else {
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
|
||||
[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
device otype *C [[buffer(2)]], \
|
||||
const constant GEMMSpiltKParams* params [[buffer(3)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||
|
||||
// Emits the supported (bm, bn, bk, wm, wn) tile configurations for one
// input/output type pair: 16x16, 16x32, 32x16 and 32x32 blocks, all with
// bk = 16 and a 2x2 simdgroup arrangement.
#define instantiate_gemm_shapes_helper(in_name, in_t, out_name, out_t) \
  instantiate_gemm_transpose_helper(in_name, in_t, out_name, out_t, 16, 16, 16, 2, 2) \
  instantiate_gemm_transpose_helper(in_name, in_t, out_name, out_t, 16, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(in_name, in_t, out_name, out_t, 32, 16, 16, 2, 2) \
  instantiate_gemm_transpose_helper(in_name, in_t, out_name, out_t, 32, 32, 16, 2, 2)
|
||||
|
||||
// Split-K GEMM instantiations. Arguments are (input name, input type,
// output name, output type); every variant here writes float output.
instantiate_gemm_shapes_helper(float16, half, float32, float);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Split k accumulation kernel
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename AccT,
|
||||
typename OutT,
|
||||
typename Epilogue = TransformNone<OutT, AccT>>
|
||||
[[kernel]] void gemm_splitk_accum(
|
||||
const device AccT *C_split [[buffer(0)]],
|
||||
device OutT *D [[buffer(1)]],
|
||||
const constant int& k_partitions [[buffer(2)]],
|
||||
const constant int& partition_stride [[buffer(3)]],
|
||||
const constant int& ldd [[buffer(4)]],
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
// Ajust D and C
|
||||
D += gid.x + gid.y * ldd;
|
||||
C_split += gid.x + gid.y * ldd;
|
||||
|
||||
int offset = 0;
|
||||
AccT out = 0;
|
||||
|
||||
for(int i = 0; i < k_partitions; i++) {
|
||||
out += C_split[offset];
|
||||
offset += partition_stride;
|
||||
}
|
||||
|
||||
// Write output
|
||||
D[0] = Epilogue::apply(out);
|
||||
|
||||
}
|
||||
|
||||
template <typename AccT,
|
||||
typename OutT,
|
||||
typename Epilogue = TransformAxpby<OutT, AccT>>
|
||||
[[kernel]] void gemm_splitk_accum_axpby(
|
||||
const device AccT *C_split [[buffer(0)]],
|
||||
device OutT *D [[buffer(1)]],
|
||||
const constant int& k_partitions [[buffer(2)]],
|
||||
const constant int& partition_stride [[buffer(3)]],
|
||||
const constant int& ldd [[buffer(4)]],
|
||||
const device OutT *C [[buffer(5)]],
|
||||
const constant int& ldc [[buffer(6)]],
|
||||
const constant int& fdc [[buffer(7)]],
|
||||
const constant float& alpha [[buffer(8)]],
|
||||
const constant float& beta [[buffer(9)]],
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
// Ajust D and C
|
||||
C += gid.x * fdc + gid.y * ldc;
|
||||
D += gid.x + gid.y * ldd;
|
||||
C_split += gid.x + gid.y * ldd;
|
||||
|
||||
int offset = 0;
|
||||
AccT out = 0;
|
||||
|
||||
for(int i = 0; i < k_partitions; i++) {
|
||||
out += C_split[offset];
|
||||
offset += partition_stride;
|
||||
}
|
||||
|
||||
// Write output
|
||||
Epilogue op(alpha, beta);
|
||||
D[0] = op.apply(out, *C);
|
||||
|
||||
}
|
||||
|
||||
// Explicitly instantiates both split-K accumulation kernels (plain and
// axpby) for one output/accumulator type pair. The exported names are
// "steel_gemm_splitk_accum_<out>_<acc>" and the same name with "_axpby".
#define instantiate_accum(out_name, out_t, acc_name, acc_t) \
  template [[host_name("steel_gemm_splitk_accum_" #out_name "_" #acc_name)]] \
  [[kernel]] void gemm_splitk_accum<acc_t, out_t>( \
      const device acc_t *C_split [[buffer(0)]], \
      device out_t *D [[buffer(1)]], \
      const constant int& k_partitions [[buffer(2)]], \
      const constant int& partition_stride [[buffer(3)]], \
      const constant int& ldd [[buffer(4)]], \
      uint2 gid [[thread_position_in_grid]]); \
  template [[host_name("steel_gemm_splitk_accum_" #out_name "_" #acc_name "_axpby")]] \
  [[kernel]] void gemm_splitk_accum_axpby<acc_t, out_t>( \
      const device acc_t *C_split [[buffer(0)]], \
      device out_t *D [[buffer(1)]], \
      const constant int& k_partitions [[buffer(2)]], \
      const constant int& partition_stride [[buffer(3)]], \
      const constant int& ldd [[buffer(4)]], \
      const device out_t *C [[buffer(5)]], \
      const constant int& ldc [[buffer(6)]], \
      const constant int& fdc [[buffer(7)]], \
      const constant float& alpha [[buffer(8)]], \
      const constant float& beta [[buffer(9)]], \
      uint2 gid [[thread_position_in_grid]]);
|
||||
|
||||
// Accumulation kernel instantiations. Arguments are (output name, output
// type, accumulator name, accumulator type); all use a float accumulator.
instantiate_accum(bfloat16, bfloat16_t, float32, float);
|
||||
instantiate_accum(float16, half, float32, float);
|
||||
instantiate_accum(float32, float, float32, float);
|
125
candle-metal-kernels/src/gemm/loader.h
Normal file
125
candle-metal-kernels/src/gemm/loader.h
Normal file
@ -0,0 +1,125 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "utils2.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Loading helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <
|
||||
typename T,
|
||||
short BROWS,
|
||||
short BCOLS,
|
||||
short dst_ld,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short alignment = 1,
|
||||
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
||||
short TCOLS = BCOLS / n_reads,
|
||||
short TROWS = tgp_size / TCOLS>
|
||||
struct BlockLoader {
|
||||
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
||||
STEEL_CONST short vec_size = n_reads;
|
||||
|
||||
// Leading dimension for src
|
||||
const int src_ld;
|
||||
const int tile_stride;
|
||||
|
||||
// Thread location indices
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
|
||||
// threadgroup and device memory
|
||||
threadgroup T* dst;
|
||||
const device T* src;
|
||||
|
||||
struct alignas(alignment * sizeof(T)) ReadVector {
|
||||
uint8_t v[sizeof(T) * vec_size];
|
||||
};
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockLoader(
|
||||
const device T* src_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
dst(dst_ + bi * dst_ld + bj),
|
||||
src(src_ + bi * src_ld + bj) {}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void load_unsafe() const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
|
||||
*((const device ReadVector*)(&src[i * src_ld]));
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||
src_tile_dim = src_tile_dim - short2(bj, bi);
|
||||
|
||||
// Skip loading if thread has no valid reads
|
||||
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Use fast thread memory for bound checks
|
||||
bool tmp_idx[vec_size];
|
||||
T tmp_val[vec_size];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
// Make sure tmp_idx only contains valid indices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
||||
}
|
||||
|
||||
// Read valid indices into tmp_val
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
}
|
||||
|
||||
// Copy values to threadgroup memory
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = tmp_val[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
src += tile_stride;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
264
candle-metal-kernels/src/gemm/mma.h
Normal file
264
candle-metal-kernels/src/gemm/mma.h
Normal file
@ -0,0 +1,264 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "gemm/transforms.h"
|
||||
#include "utils.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MMA helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
short lda_tgp,
|
||||
short ldb_tgp,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct BlockMMA {
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TM_stride = 8 * WM;
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TN_stride = 8 * WN;
|
||||
|
||||
// Warp tile size along M
|
||||
STEEL_CONST short TM = BM / TM_stride;
|
||||
// Warp tile size along N
|
||||
STEEL_CONST short TN = BN / TN_stride;
|
||||
|
||||
// Strides of A, B along reduction axis
|
||||
STEEL_CONST short simd_stride_a = {
|
||||
transpose_a ? TM_stride : TM_stride * lda_tgp};
|
||||
STEEL_CONST short simd_stride_b = {
|
||||
transpose_b ? TN_stride * ldb_tgp : TN_stride};
|
||||
|
||||
// Jump between elements
|
||||
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
|
||||
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
|
||||
|
||||
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
|
||||
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
|
||||
|
||||
// Simdgroup matrices
|
||||
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
||||
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
||||
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
||||
simdgroup_matrix<AccumType, 8, 8>(0)};
|
||||
|
||||
// Offsets within threadgroup
|
||||
const short tm;
|
||||
const short tn;
|
||||
|
||||
short sm;
|
||||
short sn;
|
||||
|
||||
short As_offset;
|
||||
short Bs_offset;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockMMA(
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
||||
// Determine thread position in simdgroup matrix
|
||||
short qid = simd_lane_id / 4;
|
||||
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
||||
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
|
||||
// Determine thread and simdgroup offset
|
||||
As_offset =
|
||||
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
|
||||
Bs_offset =
|
||||
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
|
||||
}
|
||||
|
||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
||||
// Adjust for simdgroup and thread location
|
||||
As += As_offset;
|
||||
Bs += Bs_offset;
|
||||
|
||||
// Iterate over BK in blocks of 8
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short kk = 0; kk < BK; kk += 8) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup A as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
Asimd[i].thread_elements()[0] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + 0]);
|
||||
Asimd[i].thread_elements()[1] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup B as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
Bsimd[j].thread_elements()[0] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
|
||||
Bsimd[j].thread_elements()[1] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
short j_serp = (i % 2) ? (TN - 1 - j) : j;
|
||||
|
||||
simdgroup_multiply_accumulate(
|
||||
results[i * TN + j_serp],
|
||||
Asimd[i],
|
||||
Bsimd[j_serp],
|
||||
results[i * TN + j_serp]);
|
||||
}
|
||||
}
|
||||
|
||||
// Progress to next simdgroup tile
|
||||
As += tile_stride_a;
|
||||
Bs += tile_stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(device U* C, const int ldc) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + tn + sn;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
|
||||
|
||||
// Write out C
|
||||
C[offset] = outs[0];
|
||||
C[offset + 1] = outs[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn);
|
||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (i * TM_stride < dst_tile_dims.y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * TN_stride < dst_tile_dims.x) {
|
||||
C[offset] = Epilogue::apply(accum[0]);
|
||||
}
|
||||
|
||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(
|
||||
device U* D,
|
||||
const int ldd,
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
thread const Epilogue& epilogue_op) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
||||
D += (sm + tm) * ldd + tn + sn;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
U outs[2] = {
|
||||
epilogue_op.apply(accum[0], C[offset_c]),
|
||||
epilogue_op.apply(accum[1], C[offset_c + fdc])};
|
||||
|
||||
// Write out D
|
||||
D[offset_d] = outs[0];
|
||||
D[offset_d + 1] = outs[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_result_safe(
|
||||
device U* D,
|
||||
const int ldd,
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
short2 dst_tile_dims,
|
||||
thread const Epilogue& epilogue_op) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
||||
D += (sm + tm) * ldd + tn + sn;
|
||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (i * TM_stride < dst_tile_dims.y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * TN_stride < dst_tile_dims.x) {
|
||||
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
|
||||
}
|
||||
|
||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
79
candle-metal-kernels/src/gemm/params.h
Normal file
79
candle-metal-kernels/src/gemm/params.h
Normal file
@ -0,0 +1,79 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM param classes
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
// Parameter block for the GEMM kernels: thirteen `const int` fields.
// Field order defines the memory layout and must not change.
// NOTE(review): field meanings below are inferred from their names — confirm
// against the kernels that read them.
struct GEMMParams {
  // Presumably the problem dimensions
  const int M, N, K;

  // Presumably the leading dimensions of A, B and C
  const int lda, ldb, ldc;

  const int tiles_n, tiles_m;

  const int batch_stride_a, batch_stride_b, batch_stride_c;

  const int swizzle_log;
  const int gemm_k_iterations_aligned;
};
|
||||
|
||||
// Parameter block passed to the split-K GEMM kernel (bound at buffer index 3
// by the gemm_splitk instantiations): twelve `const int` fields.
// Field order defines the memory layout and must not change.
// NOTE(review): field meanings are inferred from their names — confirm
// against the kernel that reads them.
struct GEMMSpiltKParams {
  // Presumably the problem dimensions
  const int M, N, K;

  // Presumably the leading dimensions of A, B and C
  const int lda, ldb, ldc;

  const int tiles_n, tiles_m;

  const int split_k_partitions;
  const int split_k_partition_stride;
  const int split_k_partition_size;

  const int gemm_k_iterations_aligned;
};
|
||||
|
||||
// Parameter block for the fused GEMM + add kernels: integer dimensions,
// strides and tiling values, the two float scaling factors, then `fdc`.
// Field order defines the memory layout and must not change.
// NOTE(review): field meanings are inferred from their names — confirm
// against the kernels that read them.
struct GEMMAddMMParams {
  // Presumably the problem dimensions
  const int M, N, K;

  // Presumably the leading dimensions of A, B, C and D
  const int lda, ldb, ldc, ldd;

  const int tiles_n, tiles_m;

  const int batch_stride_a, batch_stride_b, batch_stride_c, batch_stride_d;

  const int swizzle_log;
  const int gemm_k_iterations_aligned;

  const float alpha;
  const float beta;

  const int fdc;
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
BIN
candle-metal-kernels/src/gemm/steel_gemm.metallib
Normal file
BIN
candle-metal-kernels/src/gemm/steel_gemm.metallib
Normal file
Binary file not shown.
63
candle-metal-kernels/src/gemm/transforms.h
Normal file
63
candle-metal-kernels/src/gemm/transforms.h
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Transforms and Epilogues
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformNone {
|
||||
static METAL_FUNC OutT apply(InT x) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
|
||||
static METAL_FUNC OutT apply(InT x, OutT) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformAdd {
|
||||
TransformAdd(const float, const float) {}
|
||||
|
||||
static METAL_FUNC OutT apply(InT x, OutT c) {
|
||||
return static_cast<OutT>(x) + c;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformAxpby {
|
||||
const float alpha;
|
||||
const float beta;
|
||||
|
||||
TransformAxpby(const float alpha_, const float beta_)
|
||||
: alpha(alpha_), beta(beta_) {}
|
||||
|
||||
METAL_FUNC OutT apply(InT x, OutT c) const {
|
||||
return static_cast<OutT>(x * alpha + (beta * c));
|
||||
}
|
||||
};
|
||||
|
||||
// Maps an element type T to the type used for accumulation. The primary
// template shown here selects float for every T.
template <typename T>
|
||||
struct AccumHelper {
|
||||
typedef float accum_type;
|
||||
};
|
||||
|
||||
struct BlockSwizzle {
|
||||
static METAL_FUNC int2
|
||||
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
|
||||
const int tid_x = (tid.x) >> swizzle_log;
|
||||
const int tid_y =
|
||||
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
|
||||
return int2(tid_x, tid_y);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
276
candle-metal-kernels/src/gemm/utils.h
Normal file
276
candle-metal-kernels/src/gemm/utils.h
Normal file
@ -0,0 +1,276 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_math>
|
||||
#include "gemm/bf16.h"
|
||||
#include "gemm/complex.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Generic numeric limits for type U. In this primary template max/finite_max
// both come from numeric_limits<U>::max() and min/finite_min both come from
// numeric_limits<U>::min(); specializations below override this per type.
template <typename U>
|
||||
struct Limits {
|
||||
static const constant U max = metal::numeric_limits<U>::max();
|
||||
static const constant U min = metal::numeric_limits<U>::min();
|
||||
static const constant U finite_max = metal::numeric_limits<U>::max();
|
||||
static const constant U finite_min = metal::numeric_limits<U>::min();
|
||||
};
|
||||
|
||||
// Specializes Limits for a type whose max/min (and finite_max/finite_min)
// are taken directly from numeric_limits max()/min().
#define instantiate_default_limit(int_t)                                      \
  template <>                                                                 \
  struct Limits<int_t> {                                                      \
    static constexpr constant int_t max = metal::numeric_limits<int_t>::max(); \
    static constexpr constant int_t min = metal::numeric_limits<int_t>::min(); \
    static constexpr constant int_t finite_max =                              \
        metal::numeric_limits<int_t>::max();                                  \
    static constexpr constant int_t finite_min =                              \
        metal::numeric_limits<int_t>::min();                                  \
  };
|
||||
|
||||
// Limits specializations for the unsigned and signed integer types.
instantiate_default_limit(uint8_t);
|
||||
instantiate_default_limit(uint16_t);
|
||||
instantiate_default_limit(uint32_t);
|
||||
instantiate_default_limit(uint64_t);
|
||||
instantiate_default_limit(int8_t);
|
||||
instantiate_default_limit(int16_t);
|
||||
instantiate_default_limit(int32_t);
|
||||
instantiate_default_limit(int64_t);
|
||||
|
||||
// Specializes Limits for a floating-point type: max/min are +/-infinity,
// finite_max/finite_min are +/-numeric_limits::max().
#define instantiate_float_limit(float_t_)                  \
  template <>                                              \
  struct Limits<float_t_> {                                \
    static constexpr constant float_t_ max =               \
        metal::numeric_limits<float_t_>::infinity();       \
    static constexpr constant float_t_ min =               \
        -metal::numeric_limits<float_t_>::infinity();      \
    static constexpr constant float_t_ finite_max =        \
        metal::numeric_limits<float_t_>::max();            \
    static constexpr constant float_t_ finite_min =        \
        -metal::numeric_limits<float_t_>::max();           \
  };
|
||||
|
||||
// Limits specializations for the floating-point types.
instantiate_float_limit(half);
|
||||
instantiate_float_limit(float);
|
||||
instantiate_float_limit(bfloat16_t);
|
||||
|
||||
// Limits specialization for bool: max is true, min is false. Unlike the
// other specializations it declares no finite_max / finite_min members.
template <>
|
||||
struct Limits<bool> {
|
||||
static constexpr constant bool max = true;
|
||||
static constexpr constant bool min = false;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Converts a linear element index into a strided offset. The index is
// decomposed from the last dimension to the first using `shape`, and each
// per-dimension index is multiplied by the matching entry of `strides`.
// This overload reads shape/strides from the device address space.
inline size_t elem_to_loc(
    uint elem,
    device const int* shape,
    device const size_t* strides,
    int ndim) {
  size_t offset = 0;
  int axis = ndim;
  while (axis-- > 0) {
    offset += (elem % shape[axis]) * strides[axis];
    elem /= shape[axis];
  }
  return offset;
}
|
||||
|
||||
// Converts a linear element index into a strided offset. The index is
// decomposed from the last dimension to the first using `shape`, and each
// per-dimension index is multiplied by the matching entry of `strides`.
// This overload reads shape/strides from the constant address space.
inline size_t elem_to_loc(
    uint elem,
    constant const int* shape,
    constant const size_t* strides,
    int ndim) {
  size_t offset = 0;
  int axis = ndim;
  while (axis-- > 0) {
    offset += (elem % shape[axis]) * strides[axis];
    elem /= shape[axis];
  }
  return offset;
}
|
||||
|
||||
template <int NDIM>
|
||||
inline uint2 elem_to_loc_2_nd(
|
||||
uint3 elem,
|
||||
constant const int shape[NDIM],
|
||||
constant const size_t a_strides[NDIM],
|
||||
constant const size_t b_strides[NDIM]) {
|
||||
uint2 loc = {
|
||||
static_cast<uint>(
|
||||
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
|
||||
static_cast<uint>(
|
||||
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline size_t elem_to_loc_nd(
|
||||
uint3 elem,
|
||||
constant const int shape[NDIM],
|
||||
constant const size_t strides[NDIM]) {
|
||||
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
loc += (elem.z % shape[d]) * strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
// Strided offset for a one-dimensional index: elem * stride.
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
  const size_t offset = elem * stride;
  return offset;
}
|
||||
|
||||
// Strided offset for a two-dimensional index: elem.x uses strides[1] and
// elem.y uses strides[0].
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
  const size_t x_part = elem.x * strides[1];
  const size_t y_part = elem.y * strides[0];
  return x_part + y_part;
}
|
||||
|
||||
// Strided offset for a three-dimensional index: elem.x uses strides[2],
// elem.y uses strides[1] and elem.z uses strides[0].
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
  const size_t x_part = elem.x * strides[2];
  const size_t y_part = elem.y * strides[1];
  const size_t z_part = elem.z * strides[0];
  return x_part + y_part + z_part;
}
|
||||
|
||||
// Non templated version to handle arbitrary dims
|
||||
// Runtime-ndim version of the 3-component strided offset computation.
// elem.x / elem.y index the last / second-to-last dimensions directly;
// elem.z is decomposed over the remaining leading dimensions using `shape`.
inline size_t elem_to_loc(
    uint3 elem,
    constant const int* shape,
    constant const size_t* strides,
    int ndim) {
  size_t offset = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];

  // Peel the leading dimensions off elem.z, innermost first
  int axis = ndim - 2;
  while (axis-- > 0) {
    offset += (elem.z % shape[axis]) * strides[axis];
    elem.z /= shape[axis];
  }
  return offset;
}
|
||||
|
||||
// Runtime-ndim version of the two-array strided offset computation.
// elem.x / elem.y index the last / second-to-last dimensions directly;
// elem.z is decomposed over the remaining leading dimensions. The result
// holds the offset for `a_strides` in .x and for `b_strides` in .y.
inline uint2 elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    int ndim) {
  // Contribution of the two innermost dimensions
  const uint a_base = static_cast<uint>(
      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]);
  const uint b_base = static_cast<uint>(
      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]);
  uint2 offsets = {a_base, b_base};

  // Peel the leading dimensions off elem.z, innermost first
  int axis = ndim - 2;
  while (axis-- > 0) {
    uint idx = elem.z % shape[axis];
    offsets.x += idx * a_strides[axis];
    offsets.y += idx * b_strides[axis];
    elem.z /= shape[axis];
  }
  return offsets;
}
|
||||
|
||||
// Declaration only: linear-index to strided-offset conversion with the
// number of dimensions fixed at compile time. Explicit specializations for
// NDIM = 1..4 follow; no generic definition is provided here.
template <int NDIM>
|
||||
inline uint elem_to_loc_nd(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides);
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<1>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
return (elem % shape[0]) * strides[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<2>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
uint loc = (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<3>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
uint loc = (elem % shape[2]) * strides[2];
|
||||
elem /= shape[2];
|
||||
loc += (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<4>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
uint loc = (elem % shape[3]) * strides[3];
|
||||
elem /= shape[3];
|
||||
loc += (elem % shape[2]) * strides[2];
|
||||
elem /= shape[2];
|
||||
loc += (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Calculation utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** Compute ceil((float)N/(float)M) */
|
||||
// Integer division of N by M, rounded up. Implemented by biasing the
// numerator with M - 1 before the truncating division.
inline size_t ceildiv(size_t N, size_t M) {
  const size_t biased = N + M - 1;
  return biased / M;
}
|
||||
|
||||
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
||||
inline float log1p(float x) {
|
||||
float xp1 = 1.0f + x;
|
||||
if (xp1 == Limits<float>::max) {
|
||||
return Limits<float>::max;
|
||||
}
|
||||
if (xp1 == 1.0f) {
|
||||
return x;
|
||||
}
|
||||
|
||||
return x * (metal::log(xp1) / (xp1 - 1.0f));
|
||||
}
|
||||
|
||||
// bfloat16 overload of log1p. 1 + x is formed in float; the same two
// short-circuits as the float overload apply (returning
// Limits<bfloat16_t>::max or the input), and otherwise the result is
// converted back to bfloat16_t.
inline bfloat16_t log1p(bfloat16_t x) {
  const float one_plus_x = 1.0f + static_cast<float>(x);
  if (one_plus_x == Limits<float>::max) {
    return Limits<bfloat16_t>::max;
  } else if (one_plus_x == 1.0f) {
    return x;
  }

  return bfloat16_t(x * (metal::log(one_plus_x) / (one_plus_x - 1.0f)));
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SIMD shuffle ops
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// 64-bit unsigned overload: reinterprets the value as a uint2, shuffles it
// with metal::simd_shuffle_down, and reinterprets the result back.
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  const uint2 shifted = metal::simd_shuffle_down(halves, delta);
  return as_type<uint64_t>(shifted);
}
|
||||
|
||||
// 64-bit signed overload: reinterprets the value as a uint2, shuffles it
// with metal::simd_shuffle_down, and reinterprets the result back.
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  const uint2 shifted = metal::simd_shuffle_down(halves, delta);
  return as_type<int64_t>(shifted);
}
|
||||
|
||||
inline bool simd_shuffle_down(bool data, uint16_t delta) {
|
||||
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
|
||||
}
|
9
candle-metal-kernels/src/gemm/utils2.h
Normal file
9
candle-metal-kernels/src/gemm/utils2.h
Normal file
@ -0,0 +1,9 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include "gemm/host.h"
|
||||
|
||||
// Declaration prefix for compile-time constants used by the steel structs.
#define STEEL_CONST static constant constexpr const
|
||||
// Requests full unrolling of the loop that follows (clang loop pragma).
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
@ -1,6 +1,7 @@
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize,
|
||||
NSUInteger,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
@ -16,6 +17,7 @@ const CONV: &str = include_str!("conv.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const GEMM: &[u8] = include_bytes!("gemm/steel_gemm.metallib");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
|
||||
/// Most kernels apply similarly across the tensors
|
||||
@ -122,6 +124,7 @@ pub enum Source {
|
||||
Cast,
|
||||
Reduce,
|
||||
Mfa,
|
||||
Gemm,
|
||||
Conv,
|
||||
Random,
|
||||
Quantized,
|
||||
@ -248,6 +251,7 @@ impl Kernels {
|
||||
Source::Random => RANDOM,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
Source::Gemm => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -271,6 +275,14 @@ impl Kernels {
|
||||
))
|
||||
})?
|
||||
}
|
||||
Source::Gemm => {
|
||||
let source_data = GEMM;
|
||||
device.new_library_with_data(source_data).map_err(|e| {
|
||||
MetalKernelError::LoadLibraryError(format!(
|
||||
"Candle metal requires macosx > 13.0 or higher, cannot load GEMM: {e}"
|
||||
))
|
||||
})?
|
||||
}
|
||||
source => {
|
||||
let source_content = self.get_library_source(source);
|
||||
device
|
||||
@ -1230,6 +1242,34 @@ impl ConstantValues {
|
||||
}
|
||||
}
|
||||
|
||||
/// Turns an owned `String` into a `&'static str`.
///
/// The string is converted to a `Box<str>` and then leaked with `Box::leak`,
/// so its heap allocation is never freed: every call permanently retains one
/// string for the rest of the process lifetime.
fn string_to_static_str(s: String) -> &'static str {
|
||||
Box::leak(s.into_boxed_str())
|
||||
}
|
||||
|
||||
use core::ffi::c_int;
|
||||
|
||||
/// Parameter block handed to the GEMM kernel.
///
/// `call_gemm` builds one of these, copies its raw bytes into a Metal buffer
/// (`size_of::<GEMMParams>()` bytes) and binds that buffer at index 3.
/// `#[repr(C)]` fixes the field order and layout to the C rules; every field
/// is a `c_int`.
/// NOTE(review): the field order presumably has to match the parameter
/// struct declared on the kernel side - verify against the gemm headers
/// before reordering anything here.
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
struct GEMMParams {
|
||||
// Matrix dimensions; call_gemm fills these from its m, n and k values.
m: c_int,
|
||||
n: c_int,
|
||||
k: c_int,
|
||||
|
||||
// Leading dimensions of the two inputs, chosen by call_gemm from the
// stride layout it detects for each operand.
lda: c_int,
|
||||
ldb: c_int,
|
||||
// Leading dimension of the output; call_gemm passes n.
ldc: c_int,
|
||||
|
||||
// Tile counts: call_gemm computes (n + bn - 1) / bn and (m + bm - 1) / bm.
tiles_n: c_int,
|
||||
tiles_m: c_int,
|
||||
|
||||
// Batch strides. For a and b, call_gemm passes the third-from-last entry of
// the corresponding stride slice unscaled, or 0 when the slice has at most
// two entries; for c it passes m * n.
batch_stride_a: c_int,
|
||||
batch_stride_b: c_int,
|
||||
batch_stride_c: c_int,
|
||||
|
||||
// call_gemm currently always passes 0 for swizzle_log.
swizzle_log: c_int,
|
||||
// call_gemm passes k / bk (integer division).
gemm_k_iterations_aligned: c_int,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_gemm(
|
||||
device: &Device,
|
||||
@ -1251,10 +1291,10 @@ pub fn call_gemm(
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
false
|
||||
let (a_trans, lda) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
(false, k as c_int)
|
||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||
true
|
||||
(true, n as c_int)
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
@ -1262,10 +1302,10 @@ pub fn call_gemm(
|
||||
mnk: (m, n, k),
|
||||
})?;
|
||||
};
|
||||
let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
false
|
||||
let (b_trans, ldb) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
(false, n as c_int)
|
||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||
true
|
||||
(true, k as c_int)
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
@ -1273,119 +1313,195 @@ pub fn call_gemm(
|
||||
mnk: (m, n, k),
|
||||
})?;
|
||||
};
|
||||
let d_trans = false;
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
let batched = b > 1;
|
||||
let fused_activation = false;
|
||||
let fused_bias = false;
|
||||
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
||||
let m_simd = 8;
|
||||
let n_simd = 8;
|
||||
let k_simd = 64;
|
||||
let m_splits = 1;
|
||||
let n_splits = 1;
|
||||
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||
} else {
|
||||
let m_simd = 40;
|
||||
let n_simd = 40;
|
||||
let k_simd = 32;
|
||||
let m_splits = 1;
|
||||
let n_splits = 1;
|
||||
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||
};
|
||||
let constants = Some(ConstantValues::new(vec![
|
||||
(0, Value::USize(m)),
|
||||
(1, Value::USize(n)),
|
||||
(2, Value::USize(k)),
|
||||
(10, Value::Bool(a_trans)),
|
||||
(11, Value::Bool(b_trans)),
|
||||
(13, Value::Bool(d_trans)),
|
||||
(20, Value::F32(alpha)),
|
||||
(21, Value::F32(beta)),
|
||||
(100, Value::Bool(batched)),
|
||||
(101, Value::Bool(fused_activation)),
|
||||
// Garbage
|
||||
(102, Value::Bool(false)),
|
||||
(103, Value::Bool(false)),
|
||||
(113, Value::Bool(false)),
|
||||
(50_000, Value::Bool(false)),
|
||||
// End garbage
|
||||
(200, Value::U16(m_simd)),
|
||||
(201, Value::U16(n_simd)),
|
||||
(202, Value::U16(k_simd)),
|
||||
(210, Value::U16(m_splits)),
|
||||
(211, Value::U16(n_splits)),
|
||||
(50_001, Value::Bool(fused_bias)),
|
||||
]));
|
||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
|
||||
let m_group = m_simd * m_splits;
|
||||
let n_group = n_simd * n_splits;
|
||||
|
||||
let a_block_length = m_group * k_simd;
|
||||
let b_block_length = k_simd * n_group;
|
||||
|
||||
let mut block_elements = a_block_length + b_block_length;
|
||||
if (m % 8 != 0) && (n % 8 != 0) {
|
||||
let c_block_length = m_group * n_group;
|
||||
block_elements = std::cmp::max(c_block_length, block_elements)
|
||||
}
|
||||
if fused_bias {
|
||||
if d_trans {
|
||||
block_elements = std::cmp::max(block_elements, m_group);
|
||||
} else {
|
||||
block_elements = std::cmp::max(block_elements, n_group);
|
||||
}
|
||||
}
|
||||
let bytes = match name {
|
||||
"sgemm" => 4,
|
||||
"hgemm" => 2,
|
||||
// let d_trans = false;
|
||||
// let alpha = 1.0f32;
|
||||
// let beta = 0.0f32;
|
||||
// let batched = b > 1;
|
||||
// let fused_activation = false;
|
||||
// let fused_bias = false;
|
||||
// let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
||||
// let m_simd = 8;
|
||||
// let n_simd = 8;
|
||||
// let k_simd = 64;
|
||||
// let m_splits = 1;
|
||||
// let n_splits = 1;
|
||||
// (m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||
// } else {
|
||||
// let m_simd = 40;
|
||||
// let n_simd = 40;
|
||||
// let k_simd = 32;
|
||||
// let m_splits = 1;
|
||||
// let n_splits = 1;
|
||||
// (m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||
// };
|
||||
// let constants = Some(ConstantValues::new(vec![
|
||||
// (0, Value::USize(m)),
|
||||
// (1, Value::USize(n)),
|
||||
// (2, Value::USize(k)),
|
||||
// (10, Value::Bool(a_trans)),
|
||||
// (11, Value::Bool(b_trans)),
|
||||
// (13, Value::Bool(d_trans)),
|
||||
// (20, Value::F32(alpha)),
|
||||
// (21, Value::F32(beta)),
|
||||
// (100, Value::Bool(batched)),
|
||||
// (101, Value::Bool(fused_activation)),
|
||||
// // Garbage
|
||||
// (102, Value::Bool(false)),
|
||||
// (103, Value::Bool(false)),
|
||||
// (113, Value::Bool(false)),
|
||||
// (50_000, Value::Bool(false)),
|
||||
// // End garbage
|
||||
// (200, Value::U16(m_simd)),
|
||||
// (201, Value::U16(n_simd)),
|
||||
// (202, Value::U16(k_simd)),
|
||||
// (210, Value::U16(m_splits)),
|
||||
// (211, Value::U16(n_splits)),
|
||||
// (50_001, Value::Bool(fused_bias)),
|
||||
// ]));
|
||||
let a_trans_name = if a_trans { "t" } else { "n" };
|
||||
let b_trans_name = if b_trans { "t" } else { "n" };
|
||||
let (iname, oname) = match name {
|
||||
"sgemm" => ("float32", "float32"),
|
||||
"hgemm" => ("float16", "float16"),
|
||||
"bgemm" => ("bfloat16", "bfloat16"),
|
||||
other => {
|
||||
return Err(MetalKernelError::LoadLibraryError(format!(
|
||||
"{other} is not a valid kernel for gemm"
|
||||
)));
|
||||
)))
|
||||
}
|
||||
};
|
||||
let block_bytes = block_elements * bytes;
|
||||
let mut bm = 32;
|
||||
let mut bn = 32;
|
||||
let mut bk = 16;
|
||||
let wm = 2;
|
||||
let wn = 2;
|
||||
if b * m * n >= 1 << 20 {
|
||||
if !a_trans && b_trans {
|
||||
bm = 64;
|
||||
bn = if oname == "float32" { 64 } else { 32 };
|
||||
bk = if oname == "float32" { 16 } else { 32 };
|
||||
} else {
|
||||
bm = 64;
|
||||
bn = 64;
|
||||
}
|
||||
}
|
||||
let mnaligned = if m % bm == 0 && n % bn == 0 {
|
||||
"taligned"
|
||||
} else {
|
||||
"naligned"
|
||||
};
|
||||
let kaligned = if k % bk == 0 { "taligned" } else { "naligned" };
|
||||
// let bytes = match &name[..] {
|
||||
// "sgemm" => 4,
|
||||
// "hgemm" => 2,
|
||||
// other => {
|
||||
// return Err(MetalKernelError::LoadLibraryError(format!(
|
||||
// "{other} is not a valid kernel for gemm"
|
||||
// )));
|
||||
// }
|
||||
// };
|
||||
let name = format!("steel_gemm_{a_trans_name}{b_trans_name}_{iname}_{oname}_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}_MN_{mnaligned}_K_{kaligned}");
|
||||
let name = string_to_static_str(name);
|
||||
let pipeline = kernels.load_pipeline(device, Source::Gemm, name)?;
|
||||
// let m_group = m_simd * m_splits;
|
||||
// let n_group = n_simd * n_splits;
|
||||
|
||||
// let a_block_length = m_group * k_simd;
|
||||
// let b_block_length = k_simd * n_group;
|
||||
|
||||
// let mut block_elements = a_block_length + b_block_length;
|
||||
// if (m % 8 != 0) && (n % 8 != 0) {
|
||||
// let c_block_length = m_group * n_group;
|
||||
// block_elements = std::cmp::max(c_block_length, block_elements)
|
||||
// }
|
||||
// if fused_bias {
|
||||
// if d_trans {
|
||||
// block_elements = std::cmp::max(block_elements, m_group);
|
||||
// } else {
|
||||
// block_elements = std::cmp::max(block_elements, n_group);
|
||||
// }
|
||||
// }
|
||||
// let block_bytes = block_elements * bytes;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||
// encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||
|
||||
let batch_stride_a: i32 = if lhs_stride.len() > 2 {
|
||||
lhs_stride[lhs_stride.len() - 3] as i32
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let batch_stride_b: i32 = if rhs_stride.len() > 2 {
|
||||
rhs_stride[rhs_stride.len() - 3] as i32
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let batch_stride_c = (m * n) as i32;
|
||||
|
||||
let swizzle_log = 0;
|
||||
let tiles_n = ((n + bn - 1) / bn) as c_int;
|
||||
let tiles_m = ((m + bm - 1) / bm) as c_int;
|
||||
|
||||
let params = GEMMParams {
|
||||
m: m as c_int,
|
||||
n: n as c_int,
|
||||
k: k as c_int,
|
||||
lda,
|
||||
ldb,
|
||||
ldc: n as c_int,
|
||||
tiles_m,
|
||||
tiles_n,
|
||||
batch_stride_a,
|
||||
batch_stride_b,
|
||||
batch_stride_c,
|
||||
swizzle_log,
|
||||
gemm_k_iterations_aligned: (k / bk) as c_int,
|
||||
};
|
||||
let params_buffer = device.new_buffer_with_data(
|
||||
¶ms as *const GEMMParams as *const c_void,
|
||||
core::mem::size_of::<GEMMParams>() as u64,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
||||
encoder.set_buffer(2, Some(output), 0);
|
||||
encoder.set_buffer(3, Some(¶ms_buffer), 0);
|
||||
// TODO Tensor D
|
||||
|
||||
let grid_z = b;
|
||||
if batched {
|
||||
let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
|
||||
let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
|
||||
let byte_stride_c = m * n * bytes as usize;
|
||||
// TODO byte_stride_d
|
||||
let byte_stride_d = 0;
|
||||
// if batched {
|
||||
// let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
|
||||
// let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
|
||||
// let byte_stride_c = m * n * bytes as usize;
|
||||
// // TODO byte_stride_d
|
||||
// let byte_stride_d = 0;
|
||||
|
||||
let buffer: Vec<u64> = vec![
|
||||
byte_stride_a as _,
|
||||
byte_stride_b as _,
|
||||
byte_stride_c as _,
|
||||
byte_stride_d as _,
|
||||
];
|
||||
encoder.set_bytes(
|
||||
10,
|
||||
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
|
||||
buffer.as_ptr() as *const NSUInteger as *const c_void,
|
||||
);
|
||||
}
|
||||
// let buffer: Vec<u64> = vec![
|
||||
// byte_stride_a as _,
|
||||
// byte_stride_b as _,
|
||||
// byte_stride_c as _,
|
||||
// byte_stride_d as _,
|
||||
// ];
|
||||
// // encoder.set_bytes(
|
||||
// // 10,
|
||||
// // (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
|
||||
// // buffer.as_ptr() as *const NSUInteger as *const c_void,
|
||||
// // );
|
||||
// }
|
||||
let tile = 1 << swizzle_log;
|
||||
let tm = (tiles_m + tile - 1) / tile;
|
||||
let tn = tiles_n * tile;
|
||||
|
||||
let grid_size = MTLSize {
|
||||
width: divide(n, n_group.into()),
|
||||
height: divide(m, m_group.into()),
|
||||
width: tn as u64,
|
||||
height: tm as u64,
|
||||
depth: grid_z as NSUInteger,
|
||||
};
|
||||
let group_size = MTLSize {
|
||||
width: 32 * (m_splits as u64) * (n_splits as u64),
|
||||
height: 1,
|
||||
depth: 1,
|
||||
width: 32,
|
||||
height: wn,
|
||||
depth: wm,
|
||||
};
|
||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||
|
Reference in New Issue
Block a user