mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
1 Commits
0.9.0-alph
...
bf16_metal
Author | SHA1 | Date | |
---|---|---|---|
e2bf0adc2a |
2
candle-metal-kernels/compile.sh
Normal file
2
candle-metal-kernels/compile.sh
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# Compile the steel GEMM Metal source; -I src/ adds src/ to the include search path.
xcrun metal -c src/gemm/kernels/steel_gemm.metal -I src/
|
||||||
|
# Build src/gemm/steel_gemm.metallib from steel_gemm.air.
xcrun metallib steel_gemm.air -o src/gemm/steel_gemm.metallib
|
317
candle-metal-kernels/src/gemm/bf16.h
Normal file
317
candle-metal-kernels/src/gemm/bf16.h
Normal file
@ -0,0 +1,317 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
#if defined(__HAVE_BFLOAT__)
|
||||||
|
|
||||||
|
// __HAVE_BFLOAT__ is defined: bfloat16_t is simply the built-in bfloat type.
typedef bfloat bfloat16_t;
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Helpers
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Converts a float to a 16-bit bfloat bit pattern.
//
// NaN inputs map to the fixed pattern 0x7FC0. Every other input is rounded to
// nearest-even and its upper 16 bits are returned.
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
  const uint32_t raw = as_type<uint32_t>(x);

  // With the sign bit masked off, anything above the infinity encoding is NaN.
  if ((raw & ~_fp_encoding_traits<float>::sign_mask) >
      _fp_encoding_traits<float>::inf_mask) {
    return uint16_t(0x7FC0);
  }

  // Round to nearest even: add 0x7FFF plus the lowest bit that will be kept.
  const uint32_t kept_lsb = (raw >> 16) & 1;
  const uint32_t rounded = raw + (kept_lsb + uint32_t(0x7FFF));

  // The result is the upper half of the rounded pattern.
  return rounded >> 16;
}
|
||||||
|
|
||||||
|
// Expands a 16-bit bfloat bit pattern into a float: the 16 bits become the
// upper half of the 32-bit pattern and the lower half is zero.
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
  const uint32_t widened = static_cast<uint32_t>(x) << 16;
  return as_type<float>(widened);
}
|
||||||
|
|
||||||
|
struct _MLX_BFloat16;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static constexpr constant bool can_convert_to_bfloat =
|
||||||
|
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static constexpr constant bool can_convert_from_bfloat =
|
||||||
|
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Bfloat struct
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
struct _MLX_BFloat16 {
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Constructors
|
||||||
|
uint16_t bits_;
|
||||||
|
_MLX_BFloat16() thread = default;
|
||||||
|
_MLX_BFloat16() threadgroup = default;
|
||||||
|
_MLX_BFloat16() device = default;
|
||||||
|
_MLX_BFloat16() constant = default;
|
||||||
|
|
||||||
|
struct bits_to_bfloat_struct {};
|
||||||
|
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
|
||||||
|
return bits_to_bfloat_struct();
|
||||||
|
}
|
||||||
|
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
|
||||||
|
: bits_(bits) {}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Conversions to bfloat
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
|
||||||
|
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
|
||||||
|
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC _MLX_BFloat16(T x) device
|
||||||
|
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
|
||||||
|
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Conversions from bfloat
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC operator T() const thread {
|
||||||
|
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC operator T() const threadgroup {
|
||||||
|
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC operator T() const device {
|
||||||
|
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
|
||||||
|
constexpr METAL_FUNC operator T() const constant {
|
||||||
|
return static_cast<T>(bfloat_bits_to_float(bits_));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Bfloat operators
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Unary ops
|
||||||
|
// Unary minus: negate in float, then convert the result back to bfloat.
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
  return -(static_cast<float>(x));
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Binary operators
|
||||||
|
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
|
||||||
|
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
|
||||||
|
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
|
||||||
|
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
|
||||||
|
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||||
|
} \
|
||||||
|
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
|
||||||
|
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Arithmetic Operators
|
||||||
|
#define bfloat_binop(_op_, _operator_) \
|
||||||
|
bfloat_binop_base( \
|
||||||
|
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
|
||||||
|
bfloat_binop_helper(_op_, _operator_, float, float, float); \
|
||||||
|
bfloat_binop_helper(_op_, _operator_, float, half, float); \
|
||||||
|
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
|
||||||
|
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
|
||||||
|
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
|
||||||
|
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
|
||||||
|
|
||||||
|
bfloat_binop(+, operator+);
|
||||||
|
bfloat_binop(-, operator-);
|
||||||
|
bfloat_binop(*, operator*);
|
||||||
|
bfloat_binop(/, operator/);
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Comparison ops
|
||||||
|
#define bfloat_compop(__op__, __operator__) \
|
||||||
|
bfloat_binop_base( \
|
||||||
|
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
|
||||||
|
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
|
||||||
|
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
|
||||||
|
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
||||||
|
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
||||||
|
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
||||||
|
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
||||||
|
|
||||||
|
bfloat_compop(>, operator>);
|
||||||
|
bfloat_compop(<, operator<);
|
||||||
|
bfloat_compop(>=, operator>=);
|
||||||
|
bfloat_compop(<=, operator<=);
|
||||||
|
bfloat_compop(==, operator==);
|
||||||
|
bfloat_compop(!=, operator!=);
|
||||||
|
|
||||||
|
#undef bfloat_compop
|
||||||
|
#undef bfloat_binop_base
|
||||||
|
#undef bfloat_binop_helper
|
||||||
|
#undef bfloat_binop
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Inplace Operators
|
||||||
|
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
|
||||||
|
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
||||||
|
addr_space _MLX_BFloat16& lhs, itype rhs) { \
|
||||||
|
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||||
|
return lhs; \
|
||||||
|
} \
|
||||||
|
constexpr METAL_FUNC addr_space itype& __operator__( \
|
||||||
|
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
|
||||||
|
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||||
|
return lhs; \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
|
||||||
|
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
|
||||||
|
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
|
||||||
|
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
|
||||||
|
|
||||||
|
#define bfloat_inplace_op(itype) \
|
||||||
|
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
|
||||||
|
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
|
||||||
|
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
|
||||||
|
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
|
||||||
|
|
||||||
|
bfloat_inplace_op(float);
|
||||||
|
bfloat_inplace_op(half);
|
||||||
|
bfloat_inplace_op(int16_t);
|
||||||
|
bfloat_inplace_op(int32_t);
|
||||||
|
bfloat_inplace_op(int64_t);
|
||||||
|
bfloat_inplace_op(uint16_t);
|
||||||
|
bfloat_inplace_op(uint32_t);
|
||||||
|
bfloat_inplace_op(uint64_t);
|
||||||
|
|
||||||
|
#undef bfloat_inplace_op_helper
|
||||||
|
#undef bfloat_inplace_op_addr_space_helper
|
||||||
|
#undef bfloat_inplace_op
|
||||||
|
|
||||||
|
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
|
||||||
|
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
|
||||||
|
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
|
||||||
|
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
||||||
|
return lhs; \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
|
||||||
|
bfloat_inplace_op_helper(__op__, __operator__, device); \
|
||||||
|
bfloat_inplace_op_helper(__op__, __operator__, thread); \
|
||||||
|
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
|
||||||
|
|
||||||
|
bfloat_inplace_op_addr_space_helper(+, operator+=);
|
||||||
|
bfloat_inplace_op_addr_space_helper(-, operator-=);
|
||||||
|
bfloat_inplace_op_addr_space_helper(*, operator*=);
|
||||||
|
bfloat_inplace_op_addr_space_helper(/, operator/=);
|
||||||
|
|
||||||
|
#undef bfloat_inplace_op_helper
|
||||||
|
#undef bfloat_inplace_op_addr_space_helper
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Bfloat typedef
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// __HAVE_BFLOAT__ is not defined: bfloat16_t is the _MLX_BFloat16 struct.
typedef struct _MLX_BFloat16 bfloat16_t;
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Bfloat numeric limits
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
#pragma METAL internals : enable
|
||||||
|
|
||||||
|
namespace metal {
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
|
||||||
|
static constexpr constant int digits = 8;
|
||||||
|
static constexpr constant int digits10 = 2;
|
||||||
|
static constexpr constant int max_digits10 = 4;
|
||||||
|
static constexpr constant int radix = 2;
|
||||||
|
static constexpr constant int min_exponent = -125;
|
||||||
|
static constexpr constant int min_exponent10 = -37;
|
||||||
|
static constexpr constant int max_exponent = 128;
|
||||||
|
static constexpr constant int max_exponent10 = 38;
|
||||||
|
|
||||||
|
static constexpr bfloat16_t min() {
|
||||||
|
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t lowest() {
|
||||||
|
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t max() {
|
||||||
|
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t epsilon() {
|
||||||
|
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t round_error() {
|
||||||
|
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t infinity() {
|
||||||
|
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t quiet_NaN() {
|
||||||
|
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t signaling_NaN() {
|
||||||
|
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
static constexpr bfloat16_t denorm_min() {
|
||||||
|
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Reports whether x is NaN, using the fact that only a NaN compares unequal
// to itself. The comparison goes through the bfloat operator!=.
METAL_FUNC bool isnan(_MLX_BFloat16 x) {
  return (x != x);
}
|
||||||
|
|
||||||
|
} // namespace metal
|
||||||
|
|
||||||
|
#pragma METAL internals : disable
|
||||||
|
|
||||||
|
#endif // defined(__HAVE_BFLOAT__)
|
||||||
|
|
||||||
|
#include "gemm/bf16_math.h"
|
394
candle-metal-kernels/src/gemm/bf16_math.h
Normal file
394
candle-metal-kernels/src/gemm/bf16_math.h
Normal file
@ -0,0 +1,394 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "gemm/bf16.h"
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Metal math for bfloat16
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
Following the Metal Shading Language Specification (Metal 3.1)
|
||||||
|
|
||||||
|
"bfloat is an extended floating point type that only allows implicit conversion
|
||||||
|
to a type of greater floating point rank. While bfloat can be implicitly
|
||||||
|
converted to float, it cannot be implicitly converted to half, and neither
|
||||||
|
float nor half can be implicitly converted to bfloat."
|
||||||
|
|
||||||
|
Further, as far as I can tell, the stdlib math/simd functions are not defined
|
||||||
|
for bfloat and calling with an argument of type bfloat will result in that
|
||||||
|
argument getting implicitly converted to float which then returns an output
|
||||||
|
that is (likely) a float which cannot be implicitly converted into a bfloat
|
||||||
|
|
||||||
|
This leads to situations where
|
||||||
|
bfloat a = 5.0bf;
|
||||||
|
bfloat b = metal::abs(a); // this will throw an error since abs returns float
|
||||||
|
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
|
||||||
|
|
||||||
|
For the moment, I will be adding overloaded instantiations of the math
|
||||||
|
functions to accordingly automatically handle the casting
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Building blocks for instantiate_metal_math_funcs. Each emits one overload
// that casts its arguments to `ctype`, calls `builtin` with `mfast` as the
// trailing argument, and casts the result to `otype`. They are deliberately
// not #undef'd: instantiate_metal_math_funcs expands to them at its use sites.
#define bfloat_math_unary(name, builtin, itype, otype, ctype, mfast)  \
  METAL_FUNC otype name(itype x) {                                    \
    return static_cast<otype>(builtin(static_cast<ctype>(x), mfast)); \
  }

#define bfloat_math_binary(name, builtin, itype, otype, ctype, mfast)  \
  METAL_FUNC otype name(itype x, itype y) {                            \
    return static_cast<otype>(                                         \
        builtin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  }

#define bfloat_math_ternary(name, builtin, itype, otype, ctype, mfast) \
  METAL_FUNC otype name(itype x, itype y, itype z) {                   \
    return static_cast<otype>(builtin(                                 \
        static_cast<ctype>(x),                                         \
        static_cast<ctype>(y),                                         \
        static_cast<ctype>(z),                                         \
        mfast));                                                       \
  }

// Defines the math overloads taking `itype` and returning `otype`, computed
// in `ctype`. The functions that do not fit the helper shapes (fdim, fma,
// frexp, ldexp, nextafter) are written out in full.
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast)               \
  bfloat_math_unary(abs, __metal_fabs, itype, otype, ctype, mfast)             \
  bfloat_math_unary(acos, __metal_acos, itype, otype, ctype, mfast)            \
  bfloat_math_unary(acosh, __metal_acosh, itype, otype, ctype, mfast)          \
  bfloat_math_unary(asin, __metal_asin, itype, otype, ctype, mfast)            \
  bfloat_math_unary(asinh, __metal_asinh, itype, otype, ctype, mfast)          \
  bfloat_math_unary(atan, __metal_atan, itype, otype, ctype, mfast)            \
  bfloat_math_binary(atan2, __metal_atan2, itype, otype, ctype, mfast)         \
  bfloat_math_unary(atanh, __metal_atanh, itype, otype, ctype, mfast)          \
  bfloat_math_unary(ceil, __metal_ceil, itype, otype, ctype, mfast)            \
  bfloat_math_unary(cos, __metal_cos, itype, otype, ctype, mfast)              \
  bfloat_math_unary(cosh, __metal_cosh, itype, otype, ctype, mfast)            \
  bfloat_math_unary(cospi, __metal_cospi, itype, otype, ctype, mfast)          \
  bfloat_math_binary(divide, __metal_divide, itype, otype, ctype, mfast)       \
  bfloat_math_unary(exp, __metal_exp, itype, otype, ctype, mfast)              \
  bfloat_math_unary(exp10, __metal_exp10, itype, otype, ctype, mfast)          \
  bfloat_math_unary(exp2, __metal_exp2, itype, otype, ctype, mfast)            \
  bfloat_math_unary(fabs, __metal_fabs, itype, otype, ctype, mfast)            \
  METAL_FUNC otype fdim(itype x, itype y) {                                    \
    ctype t = static_cast<ctype>(x - y);                                       \
    return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y));    \
  }                                                                            \
  bfloat_math_unary(floor, __metal_floor, itype, otype, ctype, mfast)          \
  METAL_FUNC otype fma(itype x, itype y, itype z) {                            \
    return static_cast<otype>(__metal_fma(                                     \
        static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
  }                                                                            \
  bfloat_math_binary(fmax, __metal_fmax, itype, otype, ctype, mfast)           \
  bfloat_math_ternary(fmax3, __metal_fmax3, itype, otype, ctype, mfast)        \
  bfloat_math_ternary(fmedian3, __metal_fmedian3, itype, otype, ctype, mfast)  \
  bfloat_math_binary(fmin, __metal_fmin, itype, otype, ctype, mfast)           \
  bfloat_math_ternary(fmin3, __metal_fmin3, itype, otype, ctype, mfast)        \
  bfloat_math_binary(fmod, __metal_fmod, itype, otype, ctype, mfast)           \
  bfloat_math_unary(fract, __metal_fract, itype, otype, ctype, mfast)          \
  METAL_FUNC otype frexp(itype x, thread int& exp) {                           \
    return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp));     \
  }                                                                            \
  METAL_FUNC otype ldexp(itype x, int k) {                                     \
    return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
  }                                                                            \
  bfloat_math_unary(log, __metal_log, itype, otype, ctype, mfast)              \
  bfloat_math_unary(log10, __metal_log10, itype, otype, ctype, mfast)          \
  bfloat_math_unary(log2, __metal_log2, itype, otype, ctype, mfast)            \
  bfloat_math_binary(max, __metal_fmax, itype, otype, ctype, mfast)            \
  bfloat_math_ternary(max3, __metal_fmax3, itype, otype, ctype, mfast)         \
  bfloat_math_ternary(median3, __metal_fmedian3, itype, otype, ctype, mfast)   \
  bfloat_math_binary(min, __metal_fmin, itype, otype, ctype, mfast)            \
  bfloat_math_ternary(min3, __metal_fmin3, itype, otype, ctype, mfast)         \
  METAL_FUNC otype nextafter(itype x, itype y) {                               \
    return static_cast<otype>(                                                 \
        __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y)));      \
  }                                                                            \
  bfloat_math_binary(pow, __metal_pow, itype, otype, ctype, mfast)             \
  bfloat_math_binary(powr, __metal_powr, itype, otype, ctype, mfast)           \
  bfloat_math_unary(rint, __metal_rint, itype, otype, ctype, mfast)            \
  bfloat_math_unary(round, __metal_round, itype, otype, ctype, mfast)          \
  bfloat_math_unary(rsqrt, __metal_rsqrt, itype, otype, ctype, mfast)          \
  bfloat_math_unary(sin, __metal_sin, itype, otype, ctype, mfast)              \
  bfloat_math_unary(sinh, __metal_sinh, itype, otype, ctype, mfast)            \
  bfloat_math_unary(sinpi, __metal_sinpi, itype, otype, ctype, mfast)          \
  bfloat_math_unary(sqrt, __metal_sqrt, itype, otype, ctype, mfast)            \
  bfloat_math_unary(tan, __metal_tan, itype, otype, ctype, mfast)              \
  bfloat_math_unary(tanh, __metal_tanh, itype, otype, ctype, mfast)            \
  bfloat_math_unary(tanpi, __metal_tanpi, itype, otype, ctype, mfast)          \
  bfloat_math_unary(trunc, __metal_trunc, itype, otype, ctype, mfast)
|
||||||
|
|
||||||
|
namespace metal {
|
||||||
|
|
||||||
|
instantiate_metal_math_funcs(
|
||||||
|
bfloat16_t,
|
||||||
|
bfloat16_t,
|
||||||
|
float,
|
||||||
|
__METAL_MAYBE_FAST_MATH__);
|
||||||
|
|
||||||
|
namespace fast {
|
||||||
|
|
||||||
|
instantiate_metal_math_funcs(
|
||||||
|
bfloat16_t,
|
||||||
|
bfloat16_t,
|
||||||
|
float,
|
||||||
|
__METAL_FAST_MATH__);
|
||||||
|
|
||||||
|
} // namespace fast
|
||||||
|
|
||||||
|
namespace precise {
|
||||||
|
|
||||||
|
instantiate_metal_math_funcs(
|
||||||
|
bfloat16_t,
|
||||||
|
bfloat16_t,
|
||||||
|
float,
|
||||||
|
__METAL_PRECISE_MATH__);
|
||||||
|
|
||||||
|
} // namespace precise
|
||||||
|
|
||||||
|
} // namespace metal
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Metal simd for bfloat16
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Building blocks for instantiate_metal_simd_comm_funcs. They are
// deliberately not #undef'd: the main macro expands to them at its use sites.
//
// bfloat_simd_lane_fn emits `otype name(itype data, ushort arg)`: `data` is
// converted with to_ctype, passed to `builtin` together with `arg`, and the
// result is converted back with to_otype.
#define bfloat_simd_lane_fn(name, builtin, itype, otype, to_ctype, to_otype) \
  METAL_FUNC otype name(itype data, ushort arg) {                            \
    return to_otype(builtin(to_ctype(data), arg));                           \
  }

// bfloat_simd_fill_fn emits the two shuffle-and-fill overloads: one with an
// explicit modulo, and one that passes __metal_get_simdgroup_size(ushort())
// as the modulo.
#define bfloat_simd_fill_fn(name, builtin, itype, otype, to_ctype, to_otype) \
  METAL_FUNC otype name(                                                     \
      itype data, itype filling_data, ushort delta, ushort modulo) {         \
    return to_otype(builtin(                                                 \
        to_ctype(data), to_ctype(filling_data), delta, modulo));             \
  }                                                                          \
  METAL_FUNC otype name(itype data, itype filling_data, ushort delta) {      \
    return to_otype(builtin(                                                 \
        to_ctype(data),                                                      \
        to_ctype(filling_data),                                              \
        delta,                                                               \
        __metal_get_simdgroup_size(ushort())));                              \
  }

// Defines the simd data-movement overloads for `itype`. Values are passed to
// the builtins via itype_to_ctype and converted back via ctype_to_otype.
// `ctype` is accepted but not referenced by the expansion.
#define instantiate_metal_simd_comm_funcs(                                  \
    itype, otype, ctype, itype_to_ctype, ctype_to_otype)                    \
  bfloat_simd_lane_fn(                                                      \
      simd_broadcast,                                                       \
      __metal_simd_broadcast,                                               \
      itype, otype, itype_to_ctype, ctype_to_otype)                         \
  bfloat_simd_lane_fn(                                                      \
      simd_shuffle,                                                         \
      __metal_simd_shuffle,                                                 \
      itype, otype, itype_to_ctype, ctype_to_otype)                         \
  bfloat_simd_fill_fn(                                                      \
      simd_shuffle_and_fill_down,                                           \
      __metal_simd_shuffle_and_fill_down,                                   \
      itype, otype, itype_to_ctype, ctype_to_otype)                         \
  bfloat_simd_fill_fn(                                                      \
      simd_shuffle_and_fill_up,                                             \
      __metal_simd_shuffle_and_fill_up,                                     \
      itype, otype, itype_to_ctype, ctype_to_otype)                         \
  bfloat_simd_lane_fn(                                                      \
      simd_shuffle_down,                                                    \
      __metal_simd_shuffle_down,                                            \
      itype, otype, itype_to_ctype, ctype_to_otype)                         \
  bfloat_simd_lane_fn(                                                      \
      simd_shuffle_rotate_down,                                             \
      __metal_simd_shuffle_rotate_down,                                     \
      itype, otype, itype_to_ctype, ctype_to_otype)                         \
  bfloat_simd_lane_fn(                                                      \
      simd_shuffle_rotate_up,                                               \
      __metal_simd_shuffle_rotate_up,                                       \
      itype, otype, itype_to_ctype, ctype_to_otype)                         \
  bfloat_simd_lane_fn(                                                      \
      simd_shuffle_up,                                                      \
      __metal_simd_shuffle_up,                                              \
      itype, otype, itype_to_ctype, ctype_to_otype)                         \
  bfloat_simd_lane_fn(                                                      \
      simd_shuffle_xor,                                                     \
      __metal_simd_shuffle_xor,                                             \
      itype, otype, itype_to_ctype, ctype_to_otype)
|
||||||
|
|
||||||
|
// Building block for instantiate_metal_simd_reduction_funcs: emits
// `otype name(itype data)`, which casts `data` to `ctype`, calls `builtin`,
// and casts the result to `otype`. Deliberately not #undef'd, because the
// main macro expands to it at its use sites.
#define bfloat_simd_reduce_fn(name, builtin, itype, otype, ctype) \
  METAL_FUNC otype name(itype data) {                             \
    return static_cast<otype>(builtin(static_cast<ctype>(data))); \
  }

// Defines the simd reduction and prefix-scan overloads for `itype`.
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype)          \
  bfloat_simd_reduce_fn(simd_max, __metal_simd_max, itype, otype, ctype)     \
  bfloat_simd_reduce_fn(simd_min, __metal_simd_min, itype, otype, ctype)     \
  bfloat_simd_reduce_fn(                                                     \
      simd_prefix_exclusive_product,                                         \
      __metal_simd_prefix_exclusive_product,                                 \
      itype, otype, ctype)                                                   \
  bfloat_simd_reduce_fn(                                                     \
      simd_prefix_exclusive_sum,                                             \
      __metal_simd_prefix_exclusive_sum,                                     \
      itype, otype, ctype)                                                   \
  bfloat_simd_reduce_fn(                                                     \
      simd_prefix_inclusive_product,                                         \
      __metal_simd_prefix_inclusive_product,                                 \
      itype, otype, ctype)                                                   \
  bfloat_simd_reduce_fn(                                                     \
      simd_prefix_inclusive_sum,                                             \
      __metal_simd_prefix_inclusive_sum,                                     \
      itype, otype, ctype)                                                   \
  bfloat_simd_reduce_fn(                                                     \
      simd_product, __metal_simd_product, itype, otype, ctype)               \
  bfloat_simd_reduce_fn(simd_sum, __metal_simd_sum, itype, otype, ctype)     \
  bfloat_simd_reduce_fn(simd_xor, __metal_simd_xor, itype, otype, ctype)
|
||||||
|
|
||||||
|
// Bit-level conversions between bfloat16_t and uint16_t, used to pass bfloat
// values through the simd data-movement builtins.
#if defined(__HAVE_BFLOAT__)

// Native bfloat: reinterpret the bits with as_type.
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

#else

// Software bfloat: read or wrap the raw bits_ member. The argument is
// parenthesized so that expressions such as `*p` or `a + b` bind as a whole
// before `.bits_` is applied.
#define bfloat16_to_uint16(x) ((x).bits_)
#define uint16_to_bfloat16(x) \
  _MLX_BFloat16((x), _MLX_BFloat16::bits_to_bfloat())

#endif
|
||||||
|
|
||||||
|
namespace metal {
|
||||||
|
|
||||||
|
instantiate_metal_simd_comm_funcs(
|
||||||
|
bfloat16_t,
|
||||||
|
bfloat16_t,
|
||||||
|
uint16_t,
|
||||||
|
bfloat16_to_uint16,
|
||||||
|
uint16_to_bfloat16);
|
||||||
|
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
|
||||||
|
|
||||||
|
} // namespace metal
|
131
candle-metal-kernels/src/gemm/complex.h
Normal file
131
candle-metal-kernels/src/gemm/complex.h
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
struct complex64_t;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static constexpr constant bool can_convert_to_complex64 =
|
||||||
|
!is_same_v<T, complex64_t> && is_convertible_v<T, float>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static constexpr constant bool can_convert_from_complex64 =
|
||||||
|
!is_same_v<T, complex64_t> &&
|
||||||
|
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
|
||||||
|
|
||||||
|
struct complex64_t {
|
||||||
|
float real;
|
||||||
|
float imag;
|
||||||
|
|
||||||
|
// Constructors
|
||||||
|
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
|
||||||
|
|
||||||
|
// Conversions to complex64_t
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||||
|
constexpr complex64_t(T x) thread : real(x), imag(0) {}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||||
|
constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||||
|
constexpr complex64_t(T x) device : real(x), imag(0) {}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_to_complex64<T>>::type>
|
||||||
|
constexpr complex64_t(T x) constant : real(x), imag(0) {}
|
||||||
|
|
||||||
|
// Conversions from complex64_t
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||||
|
constexpr operator T() const thread {
|
||||||
|
return static_cast<T>(real);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||||
|
constexpr operator T() const threadgroup {
|
||||||
|
return static_cast<T>(real);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||||
|
constexpr operator T() const device {
|
||||||
|
return static_cast<T>(real);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename = typename enable_if<can_convert_from_complex64<T>>::type>
|
||||||
|
constexpr operator T() const constant {
|
||||||
|
return static_cast<T>(real);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Unary negation: flips the sign of both components.
constexpr complex64_t operator-(complex64_t x) {
  return complex64_t(-x.real, -x.imag);
}
|
||||||
|
|
||||||
|
// Lexicographic ordering: compare real parts first, imaginary parts on a tie.
constexpr bool operator>=(complex64_t a, complex64_t b) {
  if (a.real != b.real) {
    return a.real > b.real;
  }
  return a.imag >= b.imag;
}
|
||||||
|
|
||||||
|
// Strict lexicographic ordering: real parts first, imaginary parts on a tie.
constexpr bool operator>(complex64_t a, complex64_t b) {
  if (a.real != b.real) {
    return a.real > b.real;
  }
  return a.imag > b.imag;
}
|
||||||
|
|
||||||
|
// a <= b is defined as b >= a.
constexpr bool operator<=(complex64_t a, complex64_t b) {
  return b >= a;
}
|
||||||
|
|
||||||
|
// a < b is defined as b > a.
constexpr bool operator<(complex64_t a, complex64_t b) {
  return b > a;
}
|
||||||
|
|
||||||
|
// Equal only when both the real and the imaginary parts compare equal.
constexpr bool operator==(complex64_t a, complex64_t b) {
  return !(a.real != b.real || a.imag != b.imag);
}
|
||||||
|
|
||||||
|
// Component-wise addition.
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
  return complex64_t(a.real + b.real, a.imag + b.imag);
}
|
||||||
|
|
||||||
|
// Component-wise subtraction.
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
  return complex64_t(a.real - b.real, a.imag - b.imag);
}
|
||||||
|
|
||||||
|
// Complex product: (ar*br - ai*bi) + (ar*bi + ai*br) i.
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
  return complex64_t(
      a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real);
}
|
||||||
|
|
||||||
|
// Complex quotient: a * conj(b) divided by |b|^2.
// No scaling is applied, so b == 0 yields the usual float division-by-zero
// results in both components.
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
  auto norm_sq = b.real * b.real + b.imag * b.imag;
  auto re_num = a.real * b.real + a.imag * b.imag;
  auto im_num = a.imag * b.real - a.real * b.imag;
  return complex64_t(re_num / norm_sq, im_num / norm_sq);
}
|
||||||
|
|
||||||
|
// Component-wise remainder. Each component is reduced with a truncated
// quotient, then shifted by the divisor when the remainder is non-zero and
// its sign differs from the divisor's, so the result takes the divisor's sign.
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
  auto re = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
  auto im = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));

  const bool re_sign_differs = (re < 0) != (b.real < 0);
  if (re != 0 && re_sign_differs) {
    re += b.real;
  }

  const bool im_sign_differs = (im < 0) != (b.imag < 0);
  if (im != 0 && im_sign_differs) {
    im += b.imag;
  }

  return complex64_t(re, im);
}
|
292
candle-metal-kernels/src/gemm/gemm.h
Normal file
292
candle-metal-kernels/src/gemm/gemm.h
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "gemm/loader.h"
|
||||||
|
#include "gemm/mma.h"
|
||||||
|
#include "gemm/transforms.h"
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernel class
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace mlx {
|
||||||
|
namespace steel {
|
||||||
|
|
||||||
|
// Empty tag type carrying the three alignment flags. Passing an instance to
// GEMMKernel::gemm_loop lets its template arguments be deduced.
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename U,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
bool MN_aligned,
|
||||||
|
bool K_aligned,
|
||||||
|
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||||
|
typename Epilogue = TransformNone<U, AccumType>>
|
||||||
|
struct GEMMKernel {
|
||||||
|
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
|
||||||
|
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
|
||||||
|
STEEL_CONST short tgp_mem_size_a =
|
||||||
|
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
||||||
|
STEEL_CONST short tgp_mem_size_b =
|
||||||
|
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
||||||
|
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
||||||
|
|
||||||
|
STEEL_CONST short tgp_size = WM * WN * 32;
|
||||||
|
|
||||||
|
using loader_a_t = BlockLoader<
|
||||||
|
T,
|
||||||
|
transpose_a ? BK : BM,
|
||||||
|
transpose_a ? BM : BK,
|
||||||
|
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
||||||
|
!transpose_a,
|
||||||
|
tgp_size>;
|
||||||
|
using loader_b_t = BlockLoader<
|
||||||
|
T,
|
||||||
|
transpose_b ? BN : BK,
|
||||||
|
transpose_b ? BK : BN,
|
||||||
|
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
||||||
|
transpose_b,
|
||||||
|
tgp_size>;
|
||||||
|
using mma_t = BlockMMA<
|
||||||
|
T,
|
||||||
|
U,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
||||||
|
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
||||||
|
AccumType,
|
||||||
|
Epilogue>;
|
||||||
|
|
||||||
|
/* Main kernel function */
|
||||||
|
template <bool M_aligned, bool N_aligned, bool K_aligned_>
|
||||||
|
static METAL_FUNC void gemm_loop(
|
||||||
|
threadgroup T* As [[threadgroup(0)]],
|
||||||
|
threadgroup T* Bs [[threadgroup(1)]],
|
||||||
|
const int gemm_k_iterations,
|
||||||
|
thread loader_a_t& loader_a,
|
||||||
|
thread loader_b_t& loader_b,
|
||||||
|
thread mma_t& mma_op,
|
||||||
|
thread const short& tgp_bm,
|
||||||
|
thread const short& tgp_bn,
|
||||||
|
thread const short& lbk,
|
||||||
|
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
|
||||||
|
// Appease the compiler
|
||||||
|
(void)l;
|
||||||
|
|
||||||
|
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||||
|
|
||||||
|
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||||
|
|
||||||
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Load elements into threadgroup
|
||||||
|
if (M_aligned) {
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
} else {
|
||||||
|
loader_a.load_safe(tile_dims_A);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (N_aligned) {
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
} else {
|
||||||
|
loader_b.load_safe(tile_dims_B);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!K_aligned_) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
short2 tile_dims_A_last =
|
||||||
|
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
|
||||||
|
short2 tile_dims_B_last =
|
||||||
|
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
||||||
|
|
||||||
|
loader_a.load_safe(tile_dims_A_last);
|
||||||
|
loader_b.load_safe(tile_dims_B_last);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Main kernel function */
|
||||||
|
static METAL_FUNC void run(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
const device T* B [[buffer(1)]],
|
||||||
|
device U* C [[buffer(2)]],
|
||||||
|
const constant GEMMParams* params [[buffer(3)]],
|
||||||
|
threadgroup T* As [[threadgroup(0)]],
|
||||||
|
threadgroup T* Bs [[threadgroup(1)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
// Pacifying compiler
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
|
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||||
|
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||||
|
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||||
|
|
||||||
|
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Find block in A, B, C
|
||||||
|
const int c_row = tid_y * BM;
|
||||||
|
const int c_col = tid_x * BN;
|
||||||
|
|
||||||
|
A += transpose_a ? c_row : c_row * params->lda;
|
||||||
|
B += transpose_b ? c_col * params->ldb : c_col;
|
||||||
|
C += c_row * params->ldc + c_col;
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||||
|
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MNK aligned loop
|
||||||
|
if (MN_aligned) {
|
||||||
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Loop tail
|
||||||
|
if (!K_aligned) {
|
||||||
|
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||||
|
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||||
|
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||||
|
|
||||||
|
loader_a.load_safe(tile_dims_A);
|
||||||
|
loader_b.load_safe(tile_dims_B);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
mma_op.store_result(C, params->ldc);
|
||||||
|
return;
|
||||||
|
|
||||||
|
}
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MN unaligned loop
|
||||||
|
else { // Loop over K - unaligned case
|
||||||
|
short tgp_bm = min(BM, params->M - c_row);
|
||||||
|
short tgp_bn = min(BN, params->N - c_col);
|
||||||
|
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||||
|
|
||||||
|
if (tgp_bm == BM && tgp_bn == BN) {
|
||||||
|
gemm_loop<true, true, K_aligned>(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk);
|
||||||
|
|
||||||
|
mma_op.store_result(C, params->ldc);
|
||||||
|
return;
|
||||||
|
|
||||||
|
} else if (tgp_bn == BN) {
|
||||||
|
gemm_loop<false, true, K_aligned>(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk);
|
||||||
|
|
||||||
|
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||||
|
return;
|
||||||
|
|
||||||
|
} else if (tgp_bm == BM) {
|
||||||
|
gemm_loop<true, false, K_aligned>(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk);
|
||||||
|
|
||||||
|
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||||
|
return;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
gemm_loop<false, false, K_aligned>(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk);
|
||||||
|
|
||||||
|
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace steel
|
||||||
|
} // namespace mlx
|
5
candle-metal-kernels/src/gemm/host.h
Normal file
5
candle-metal-kernels/src/gemm/host.h
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "params.h"
|
89
candle-metal-kernels/src/gemm/kernels/steel_gemm.metal
Normal file
89
candle-metal-kernels/src/gemm/kernels/steel_gemm.metal
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "gemm/bf16.h"
|
||||||
|
#include "gemm/gemm.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
using namespace mlx::steel;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
bool MN_aligned,
|
||||||
|
bool K_aligned>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
|
||||||
|
const device T *A [[buffer(0)]],
|
||||||
|
const device T *B [[buffer(1)]],
|
||||||
|
device T *C [[buffer(2)]],
|
||||||
|
const constant GEMMParams* params [[buffer(3)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
|
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
||||||
|
|
||||||
|
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||||
|
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||||
|
|
||||||
|
// Adjust for batch
|
||||||
|
A += params->batch_stride_a * tid.z;
|
||||||
|
B += params->batch_stride_b * tid.z;
|
||||||
|
C += params->batch_stride_c * tid.z;
|
||||||
|
|
||||||
|
gemm_kernel::run(
|
||||||
|
A, B, C,
|
||||||
|
params,
|
||||||
|
As, Bs,
|
||||||
|
simd_lane_id, simd_group_id, tid, lid
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernel initializations
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||||
|
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
|
||||||
|
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||||
|
const device itype *A [[buffer(0)]], \
|
||||||
|
const device itype *B [[buffer(1)]], \
|
||||||
|
device itype *C [[buffer(2)]], \
|
||||||
|
const constant GEMMParams* params [[buffer(3)]], \
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||||
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||||
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||||
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||||
|
|
||||||
|
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||||
|
|
||||||
|
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||||
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||||
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||||
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||||
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||||
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
||||||
|
|
||||||
|
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||||
|
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||||
|
|
||||||
|
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
254
candle-metal-kernels/src/gemm/kernels/steel_gemm_addmm.metal
Normal file
254
candle-metal-kernels/src/gemm/kernels/steel_gemm_addmm.metal
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
// Headers are resolved relative to `src/` (compile.sh passes `-I src/`), the
// same layout steel_gemm.metal uses; the upstream MLX tree paths do not match
// that layout.
#include "gemm/bf16.h"
#include "gemm/gemm.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
using namespace mlx::steel;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
bool MN_aligned,
|
||||||
|
bool K_aligned,
|
||||||
|
typename AccumType = float,
|
||||||
|
typename Epilogue = TransformAdd<T, AccumType>>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
|
||||||
|
const device T *A [[buffer(0)]],
|
||||||
|
const device T *B [[buffer(1)]],
|
||||||
|
const device T *C [[buffer(2)]],
|
||||||
|
device T *D [[buffer(3)]],
|
||||||
|
const constant GEMMAddMMParams* params [[buffer(4)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
|
// Pacifying compiler
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
|
using gemm_kernel =
|
||||||
|
GEMMKernel<T, T, BM, BN, BK, WM, WN,
|
||||||
|
transpose_a, transpose_b,
|
||||||
|
MN_aligned, K_aligned,
|
||||||
|
AccumType, Epilogue>;
|
||||||
|
|
||||||
|
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||||
|
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||||
|
using mma_t = typename gemm_kernel::mma_t;
|
||||||
|
|
||||||
|
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||||
|
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||||
|
|
||||||
|
// Adjust for batch
|
||||||
|
A += params->batch_stride_a * tid.z;
|
||||||
|
B += params->batch_stride_b * tid.z;
|
||||||
|
C += params->batch_stride_c * tid.z;
|
||||||
|
D += params->batch_stride_d * tid.z;
|
||||||
|
|
||||||
|
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||||
|
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||||
|
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||||
|
|
||||||
|
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Find block in A, B, C
|
||||||
|
const int c_row = tid_y * BM;
|
||||||
|
const int c_col = tid_x * BN;
|
||||||
|
|
||||||
|
A += transpose_a ? c_row : c_row * params->lda;
|
||||||
|
B += transpose_b ? c_col * params->ldb : c_col;
|
||||||
|
C += c_row * params->ldc + c_col * params->fdc;
|
||||||
|
D += c_row * params->ldd + c_col;
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||||
|
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||||
|
|
||||||
|
const Epilogue epilogue_op(params->alpha, params->beta);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MNK aligned loop
|
||||||
|
if (MN_aligned) {
|
||||||
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Loop tail
|
||||||
|
if (!K_aligned) {
|
||||||
|
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||||
|
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||||
|
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||||
|
|
||||||
|
loader_a.load_safe(tile_dims_A);
|
||||||
|
loader_b.load_safe(tile_dims_B);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
|
||||||
|
return;
|
||||||
|
|
||||||
|
}
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MN unaligned loop
|
||||||
|
else { // Loop over K - unaligned case
|
||||||
|
short tgp_bm = min(BM, params->M - c_row);
|
||||||
|
short tgp_bn = min(BN, params->N - c_col);
|
||||||
|
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||||
|
|
||||||
|
if (tgp_bm == BM && tgp_bn == BN) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<true, true, K_aligned>{});
|
||||||
|
|
||||||
|
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
|
||||||
|
return;
|
||||||
|
|
||||||
|
} else if (tgp_bn == BN) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<false, true, K_aligned>{});
|
||||||
|
|
||||||
|
return mma_op.store_result_safe(
|
||||||
|
D, params->ldd,
|
||||||
|
C, params->ldc, params->fdc,
|
||||||
|
short2(tgp_bn, tgp_bm),
|
||||||
|
epilogue_op);
|
||||||
|
|
||||||
|
} else if (tgp_bm == BM) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<true, false, K_aligned>{});
|
||||||
|
|
||||||
|
return mma_op.store_result_safe(
|
||||||
|
D, params->ldd,
|
||||||
|
C, params->ldc, params->fdc,
|
||||||
|
short2(tgp_bn, tgp_bm),
|
||||||
|
epilogue_op);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<false, false, K_aligned>{});
|
||||||
|
|
||||||
|
return mma_op.store_result_safe(
|
||||||
|
D, params->ldd,
|
||||||
|
C, params->ldc, params->fdc,
|
||||||
|
short2(tgp_bn, tgp_bm),
|
||||||
|
epilogue_op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernel initializations
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
|
||||||
|
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
|
||||||
|
[[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
|
||||||
|
const device itype *A [[buffer(0)]], \
|
||||||
|
const device itype *B [[buffer(1)]], \
|
||||||
|
const device itype *C [[buffer(2)]], \
|
||||||
|
device itype *D [[buffer(3)]], \
|
||||||
|
const constant GEMMAddMMParams* params [[buffer(4)]], \
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||||
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
|
||||||
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
|
||||||
|
|
||||||
|
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||||
|
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||||
|
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||||
|
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||||
|
|
||||||
|
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
|
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||||
|
|
||||||
|
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||||
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||||
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||||
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||||
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||||
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
||||||
|
|
||||||
|
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||||
|
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||||
|
|
||||||
|
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
280
candle-metal-kernels/src/gemm/kernels/steel_gemm_splitk.metal
Normal file
280
candle-metal-kernels/src/gemm/kernels/steel_gemm_splitk.metal
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
// Headers are resolved relative to `src/` (compile.sh passes `-I src/`), the
// same layout steel_gemm.metal uses; the upstream MLX tree paths do not match
// that layout.
#include "gemm/bf16.h"
#include "gemm/gemm.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
using namespace mlx::steel;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename T,
|
||||||
|
typename U,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
bool MN_aligned,
|
||||||
|
bool K_aligned>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
|
||||||
|
const device T *A [[buffer(0)]],
|
||||||
|
const device T *B [[buffer(1)]],
|
||||||
|
device U *C [[buffer(2)]],
|
||||||
|
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
|
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
||||||
|
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||||
|
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||||
|
using mma_t = typename gemm_kernel::mma_t;
|
||||||
|
|
||||||
|
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||||
|
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||||
|
|
||||||
|
const int tid_x = tid.x;
|
||||||
|
const int tid_y = tid.y;
|
||||||
|
const int tid_z = tid.z;
|
||||||
|
|
||||||
|
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find block in A, B, C
|
||||||
|
const int c_row = tid_y * BM;
|
||||||
|
const int c_col = tid_x * BN;
|
||||||
|
const int k_start = params->split_k_partition_size * tid_z;
|
||||||
|
|
||||||
|
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
|
||||||
|
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
|
||||||
|
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||||
|
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||||
|
|
||||||
|
short tgp_bm = min(BM, params->M - c_row);
|
||||||
|
short tgp_bn = min(BN, params->N - c_col);
|
||||||
|
short leftover_bk = params->K % BK;
|
||||||
|
|
||||||
|
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<true, true, true>{});
|
||||||
|
} else if (tgp_bn == BN) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<false, true, true>{});
|
||||||
|
} else if (tgp_bm == BM) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<true, false, true>{});
|
||||||
|
} else {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<false, false, true>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if ((tid_z + 1) == (params->split_k_partitions)) {
|
||||||
|
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
|
||||||
|
if(!K_aligned || gemm_k_iter_remaining > 0)
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iter_remaining,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<false, false, K_aligned>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||||
|
mma_op.store_result(C, params->ldc);
|
||||||
|
} else {
|
||||||
|
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernel initializations
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Explicitly instantiates gemm_splitk for one configuration and gives it a
// host-visible name of the form
//   steel_gemm_splitk_<trans>_<in>_<out>_bm<..>_bn<..>_bk<..>_wm<..>_wn<..>_MN_<mn tag>_K_<k tag>
// The *_tag / *_name arguments only feed the name; the remaining arguments
// are the template arguments of the kernel.
#define instantiate_gemm(trans_tag, a_is_trans, b_is_trans, in_name, in_type, out_name, out_type, tile_m, tile_n, tile_k, warps_m, warps_n, mn_tag, mn_is_aligned, k_tag, k_is_aligned) \
  template [[host_name("steel_gemm_splitk_" #trans_tag "_" #in_name "_" #out_name "_bm" #tile_m "_bn" #tile_n "_bk" #tile_k "_wm" #warps_m "_wn" #warps_n "_MN_" #mn_tag "_K_" #k_tag)]] \
  [[kernel]] void gemm_splitk<in_type, out_type, tile_m, tile_n, tile_k, warps_m, warps_n, a_is_trans, b_is_trans, mn_is_aligned, k_is_aligned>( \
      const device in_type* A [[buffer(0)]], \
      const device in_type* B [[buffer(1)]], \
      device out_type* C [[buffer(2)]], \
      const constant GEMMSpiltKParams* params [[buffer(3)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
// Expands instantiate_gemm for the four (MN aligned, K aligned) flag
// combinations. The tag "taligned" is paired with true and "naligned" with
// false; the tags end up in the kernel's host name.
#define instantiate_gemm_aligned_helper(trans_tag, a_is_trans, b_is_trans, in_name, in_type, out_name, out_type, tile_m, tile_n, tile_k, warps_m, warps_n) \
  instantiate_gemm(trans_tag, a_is_trans, b_is_trans, in_name, in_type, out_name, out_type, tile_m, tile_n, tile_k, warps_m, warps_n, taligned, true, taligned, true) \
  instantiate_gemm(trans_tag, a_is_trans, b_is_trans, in_name, in_type, out_name, out_type, tile_m, tile_n, tile_k, warps_m, warps_n, taligned, true, naligned, false) \
  instantiate_gemm(trans_tag, a_is_trans, b_is_trans, in_name, in_type, out_name, out_type, tile_m, tile_n, tile_k, warps_m, warps_n, naligned, false, taligned, true) \
  instantiate_gemm(trans_tag, a_is_trans, b_is_trans, in_name, in_type, out_name, out_type, tile_m, tile_n, tile_k, warps_m, warps_n, naligned, false, naligned, false)
|
||||||
|
|
||||||
|
// Expands instantiate_gemm_aligned_helper once per transpose combination of
// A and B: nn, nt, tn and tt map to the flag pairs (false,false),
// (false,true), (true,false) and (true,true).
#define instantiate_gemm_transpose_helper(in_name, in_type, out_name, out_type, tile_m, tile_n, tile_k, warps_m, warps_n) \
  instantiate_gemm_aligned_helper(nn, false, false, in_name, in_type, out_name, out_type, tile_m, tile_n, tile_k, warps_m, warps_n) \
  instantiate_gemm_aligned_helper(nt, false, true, in_name, in_type, out_name, out_type, tile_m, tile_n, tile_k, warps_m, warps_n) \
  instantiate_gemm_aligned_helper(tn, true, false, in_name, in_type, out_name, out_type, tile_m, tile_n, tile_k, warps_m, warps_n) \
  instantiate_gemm_aligned_helper(tt, true, true, in_name, in_type, out_name, out_type, tile_m, tile_n, tile_k, warps_m, warps_n)
|
||||||
|
|
||||||
|
// Expands instantiate_gemm_transpose_helper for each split-K tile
// configuration; the trailing numbers are (bm, bn, bk, wm, wn).
#define instantiate_gemm_shapes_helper(in_name, in_type, out_name, out_type) \
  instantiate_gemm_transpose_helper(in_name, in_type, out_name, out_type, 16, 16, 16, 2, 2) \
  instantiate_gemm_transpose_helper(in_name, in_type, out_name, out_type, 16, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(in_name, in_type, out_name, out_type, 32, 16, 16, 2, 2) \
  instantiate_gemm_transpose_helper(in_name, in_type, out_name, out_type, 32, 32, 16, 2, 2)
|
||||||
|
|
||||||
|
// Split-K kernels always write float32 partial results, whatever the input
// element type is.
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);

instantiate_gemm_shapes_helper(float32, float, float32, float);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Split k accumulation kernel
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename AccT,
|
||||||
|
typename OutT,
|
||||||
|
typename Epilogue = TransformNone<OutT, AccT>>
|
||||||
|
[[kernel]] void gemm_splitk_accum(
|
||||||
|
const device AccT *C_split [[buffer(0)]],
|
||||||
|
device OutT *D [[buffer(1)]],
|
||||||
|
const constant int& k_partitions [[buffer(2)]],
|
||||||
|
const constant int& partition_stride [[buffer(3)]],
|
||||||
|
const constant int& ldd [[buffer(4)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
|
// Ajust D and C
|
||||||
|
D += gid.x + gid.y * ldd;
|
||||||
|
C_split += gid.x + gid.y * ldd;
|
||||||
|
|
||||||
|
int offset = 0;
|
||||||
|
AccT out = 0;
|
||||||
|
|
||||||
|
for(int i = 0; i < k_partitions; i++) {
|
||||||
|
out += C_split[offset];
|
||||||
|
offset += partition_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write output
|
||||||
|
D[0] = Epilogue::apply(out);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename AccT,
|
||||||
|
typename OutT,
|
||||||
|
typename Epilogue = TransformAxpby<OutT, AccT>>
|
||||||
|
[[kernel]] void gemm_splitk_accum_axpby(
|
||||||
|
const device AccT *C_split [[buffer(0)]],
|
||||||
|
device OutT *D [[buffer(1)]],
|
||||||
|
const constant int& k_partitions [[buffer(2)]],
|
||||||
|
const constant int& partition_stride [[buffer(3)]],
|
||||||
|
const constant int& ldd [[buffer(4)]],
|
||||||
|
const device OutT *C [[buffer(5)]],
|
||||||
|
const constant int& ldc [[buffer(6)]],
|
||||||
|
const constant int& fdc [[buffer(7)]],
|
||||||
|
const constant float& alpha [[buffer(8)]],
|
||||||
|
const constant float& beta [[buffer(9)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
|
// Ajust D and C
|
||||||
|
C += gid.x * fdc + gid.y * ldc;
|
||||||
|
D += gid.x + gid.y * ldd;
|
||||||
|
C_split += gid.x + gid.y * ldd;
|
||||||
|
|
||||||
|
int offset = 0;
|
||||||
|
AccT out = 0;
|
||||||
|
|
||||||
|
for(int i = 0; i < k_partitions; i++) {
|
||||||
|
out += C_split[offset];
|
||||||
|
offset += partition_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write output
|
||||||
|
Epilogue op(alpha, beta);
|
||||||
|
D[0] = op.apply(out, *C);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicitly instantiates both reduction kernels for one
// (output type, accumulator type) pair. Host names:
//   steel_gemm_splitk_accum_<out>_<acc>
//   steel_gemm_splitk_accum_<out>_<acc>_axpby
#define instantiate_accum(out_name, out_type, acc_name, acc_type) \
  template [[host_name("steel_gemm_splitk_accum_" #out_name "_" #acc_name)]] \
  [[kernel]] void gemm_splitk_accum<acc_type, out_type>( \
      const device acc_type* C_split [[buffer(0)]], \
      device out_type* D [[buffer(1)]], \
      const constant int& k_partitions [[buffer(2)]], \
      const constant int& partition_stride [[buffer(3)]], \
      const constant int& ldd [[buffer(4)]], \
      uint2 gid [[thread_position_in_grid]]); \
  template [[host_name("steel_gemm_splitk_accum_" #out_name "_" #acc_name "_axpby")]] \
  [[kernel]] void gemm_splitk_accum_axpby<acc_type, out_type>( \
      const device acc_type* C_split [[buffer(0)]], \
      device out_type* D [[buffer(1)]], \
      const constant int& k_partitions [[buffer(2)]], \
      const constant int& partition_stride [[buffer(3)]], \
      const constant int& ldd [[buffer(4)]], \
      const device out_type* C [[buffer(5)]], \
      const constant int& ldc [[buffer(6)]], \
      const constant int& fdc [[buffer(7)]], \
      const constant float& alpha [[buffer(8)]], \
      const constant float& beta [[buffer(9)]], \
      uint2 gid [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// Reduction kernels: float32 accumulation into each supported output type.
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float);
|
125
candle-metal-kernels/src/gemm/loader.h
Normal file
125
candle-metal-kernels/src/gemm/loader.h
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "utils2.h"
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Loading helper
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace mlx {
|
||||||
|
namespace steel {
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
short BROWS,
|
||||||
|
short BCOLS,
|
||||||
|
short dst_ld,
|
||||||
|
short reduction_dim,
|
||||||
|
short tgp_size,
|
||||||
|
short alignment = 1,
|
||||||
|
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
||||||
|
short TCOLS = BCOLS / n_reads,
|
||||||
|
short TROWS = tgp_size / TCOLS>
|
||||||
|
struct BlockLoader {
|
||||||
|
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
||||||
|
STEEL_CONST short vec_size = n_reads;
|
||||||
|
|
||||||
|
// Leading dimension for src
|
||||||
|
const int src_ld;
|
||||||
|
const int tile_stride;
|
||||||
|
|
||||||
|
// Thread location indices
|
||||||
|
const short thread_idx;
|
||||||
|
const short bi;
|
||||||
|
const short bj;
|
||||||
|
|
||||||
|
// threadgroup and device memory
|
||||||
|
threadgroup T* dst;
|
||||||
|
const device T* src;
|
||||||
|
|
||||||
|
struct alignas(alignment * sizeof(T)) ReadVector {
|
||||||
|
uint8_t v[sizeof(T) * vec_size];
|
||||||
|
};
|
||||||
|
|
||||||
|
/* Constructor */
|
||||||
|
METAL_FUNC BlockLoader(
|
||||||
|
const device T* src_,
|
||||||
|
const int src_ld_,
|
||||||
|
threadgroup T* dst_,
|
||||||
|
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||||
|
: src_ld(src_ld_),
|
||||||
|
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
||||||
|
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||||
|
bi(thread_idx / TCOLS),
|
||||||
|
bj(vec_size * (thread_idx % TCOLS)),
|
||||||
|
dst(dst_ + bi * dst_ld + bj),
|
||||||
|
src(src_ + bi * src_ld + bj) {}
|
||||||
|
|
||||||
|
/* Load from device memory into threadgroup memory - without bound checking */
|
||||||
|
METAL_FUNC void load_unsafe() const {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < BROWS; i += TROWS) {
|
||||||
|
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
|
||||||
|
*((const device ReadVector*)(&src[i * src_ld]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Load from device memory into threadgroup memory - with bound checking */
|
||||||
|
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||||
|
src_tile_dim = src_tile_dim - short2(bj, bi);
|
||||||
|
|
||||||
|
// Skip loading if thread has no valid reads
|
||||||
|
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < BROWS; i += TROWS) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
dst[i * dst_ld + j] = T(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use fast thread memory for bound checks
|
||||||
|
bool tmp_idx[vec_size];
|
||||||
|
T tmp_val[vec_size];
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < BROWS; i += TROWS) {
|
||||||
|
// Make sure tmp_idx only contains valid indices
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read valid indices into tmp_val
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zero out uneeded values
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy values to threadgroup memory
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
dst[i * dst_ld + j] = tmp_val[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Iteration helper */
|
||||||
|
METAL_FUNC void next() {
|
||||||
|
src += tile_stride;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace steel
|
||||||
|
} // namespace mlx
|
264
candle-metal-kernels/src/gemm/mma.h
Normal file
264
candle-metal-kernels/src/gemm/mma.h
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "gemm/transforms.h"
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MMA helper
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace mlx {
|
||||||
|
namespace steel {
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename U,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
short lda_tgp,
|
||||||
|
short ldb_tgp,
|
||||||
|
typename AccumType = float,
|
||||||
|
typename Epilogue = TransformNone<U, AccumType>>
|
||||||
|
struct BlockMMA {
|
||||||
|
// Warp tile simdgroup matrix strides along M
|
||||||
|
STEEL_CONST short TM_stride = 8 * WM;
|
||||||
|
// Warp tile simdgroup matrix strides along M
|
||||||
|
STEEL_CONST short TN_stride = 8 * WN;
|
||||||
|
|
||||||
|
// Warp tile size along M
|
||||||
|
STEEL_CONST short TM = BM / TM_stride;
|
||||||
|
// Warp tile size along N
|
||||||
|
STEEL_CONST short TN = BN / TN_stride;
|
||||||
|
|
||||||
|
// Strides of A, B along reduction axis
|
||||||
|
STEEL_CONST short simd_stride_a = {
|
||||||
|
transpose_a ? TM_stride : TM_stride * lda_tgp};
|
||||||
|
STEEL_CONST short simd_stride_b = {
|
||||||
|
transpose_b ? TN_stride * ldb_tgp : TN_stride};
|
||||||
|
|
||||||
|
// Jump between elements
|
||||||
|
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
|
||||||
|
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
|
||||||
|
|
||||||
|
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
|
||||||
|
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
|
||||||
|
|
||||||
|
// Simdgroup matrices
|
||||||
|
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
||||||
|
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
||||||
|
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
||||||
|
simdgroup_matrix<AccumType, 8, 8>(0)};
|
||||||
|
|
||||||
|
// Offsets within threadgroup
|
||||||
|
const short tm;
|
||||||
|
const short tn;
|
||||||
|
|
||||||
|
short sm;
|
||||||
|
short sn;
|
||||||
|
|
||||||
|
short As_offset;
|
||||||
|
short Bs_offset;
|
||||||
|
|
||||||
|
/* Constructor */
|
||||||
|
METAL_FUNC BlockMMA(
|
||||||
|
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||||
|
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
||||||
|
// Determine thread position in simdgroup matrix
|
||||||
|
short qid = simd_lane_id / 4;
|
||||||
|
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
||||||
|
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||||
|
|
||||||
|
// Determine thread and simdgroup offset
|
||||||
|
As_offset =
|
||||||
|
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
|
||||||
|
Bs_offset =
|
||||||
|
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||||
|
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
||||||
|
// Adjust for simdgroup and thread location
|
||||||
|
As += As_offset;
|
||||||
|
Bs += Bs_offset;
|
||||||
|
|
||||||
|
// Iterate over BK in blocks of 8
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short kk = 0; kk < BK; kk += 8) {
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Load elements from threadgroup A as simdgroup matrices
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < TM; i++) {
|
||||||
|
Asimd[i].thread_elements()[0] =
|
||||||
|
static_cast<AccumType>(As[i * simd_stride_a + 0]);
|
||||||
|
Asimd[i].thread_elements()[1] =
|
||||||
|
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
|
||||||
|
}
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Load elements from threadgroup B as simdgroup matrices
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < TN; j++) {
|
||||||
|
Bsimd[j].thread_elements()[0] =
|
||||||
|
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
|
||||||
|
Bsimd[j].thread_elements()[1] =
|
||||||
|
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
|
||||||
|
}
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Multiply and accumulate into result simdgroup matrices
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < TM; i++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < TN; j++) {
|
||||||
|
short j_serp = (i % 2) ? (TN - 1 - j) : j;
|
||||||
|
|
||||||
|
simdgroup_multiply_accumulate(
|
||||||
|
results[i * TN + j_serp],
|
||||||
|
Asimd[i],
|
||||||
|
Bsimd[j_serp],
|
||||||
|
results[i * TN + j_serp]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Progress to next simdgroup tile
|
||||||
|
As += tile_stride_a;
|
||||||
|
Bs += tile_stride_b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Store results from simdgroup_matrix results into device memory */
|
||||||
|
METAL_FUNC void store_result(device U* C, const int ldc) const {
|
||||||
|
// Adjust for simdgroup and thread location
|
||||||
|
C += (sm + tm) * ldc + tn + sn;
|
||||||
|
|
||||||
|
// Loop over all simdgroup tiles
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < TM; i++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < TN; j++) {
|
||||||
|
// Get accumulated result and associated offset in C
|
||||||
|
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||||
|
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||||
|
|
||||||
|
// Apply epilogue
|
||||||
|
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
|
||||||
|
|
||||||
|
// Write out C
|
||||||
|
C[offset] = outs[0];
|
||||||
|
C[offset + 1] = outs[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
METAL_FUNC void
|
||||||
|
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
|
||||||
|
// Adjust for simdgroup and thread location
|
||||||
|
C += (sm + tm) * ldc + (tn + sn);
|
||||||
|
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < TM; i++) {
|
||||||
|
if (i * TM_stride < dst_tile_dims.y) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < TN; j++) {
|
||||||
|
// Get accumulated result and associated offset in C
|
||||||
|
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||||
|
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||||
|
|
||||||
|
// Apply epilogue and output C
|
||||||
|
if (j * TN_stride < dst_tile_dims.x) {
|
||||||
|
C[offset] = Epilogue::apply(accum[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||||
|
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Store results from simdgroup_matrix results into device memory */
|
||||||
|
METAL_FUNC void store_result(
|
||||||
|
device U* D,
|
||||||
|
const int ldd,
|
||||||
|
const device U* C,
|
||||||
|
const int ldc,
|
||||||
|
const int fdc,
|
||||||
|
thread const Epilogue& epilogue_op) const {
|
||||||
|
// Adjust for simdgroup and thread location
|
||||||
|
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
||||||
|
D += (sm + tm) * ldd + tn + sn;
|
||||||
|
|
||||||
|
// Loop over all simdgroup tiles
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < TM; i++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < TN; j++) {
|
||||||
|
// Get accumulated result and associated offset in C
|
||||||
|
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||||
|
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||||
|
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||||
|
|
||||||
|
// Apply epilogue
|
||||||
|
U outs[2] = {
|
||||||
|
epilogue_op.apply(accum[0], C[offset_c]),
|
||||||
|
epilogue_op.apply(accum[1], C[offset_c + fdc])};
|
||||||
|
|
||||||
|
// Write out D
|
||||||
|
D[offset_d] = outs[0];
|
||||||
|
D[offset_d + 1] = outs[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
METAL_FUNC void store_result_safe(
|
||||||
|
device U* D,
|
||||||
|
const int ldd,
|
||||||
|
const device U* C,
|
||||||
|
const int ldc,
|
||||||
|
const int fdc,
|
||||||
|
short2 dst_tile_dims,
|
||||||
|
thread const Epilogue& epilogue_op) const {
|
||||||
|
// Adjust for simdgroup and thread location
|
||||||
|
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
||||||
|
D += (sm + tm) * ldd + tn + sn;
|
||||||
|
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < TM; i++) {
|
||||||
|
if (i * TM_stride < dst_tile_dims.y) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < TN; j++) {
|
||||||
|
// Get accumulated result and associated offset in C
|
||||||
|
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||||
|
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||||
|
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||||
|
|
||||||
|
// Apply epilogue and output C
|
||||||
|
if (j * TN_stride < dst_tile_dims.x) {
|
||||||
|
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||||
|
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace steel
|
||||||
|
} // namespace mlx
|
79
candle-metal-kernels/src/gemm/params.h
Normal file
79
candle-metal-kernels/src/gemm/params.h
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM param classes
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace mlx {
|
||||||
|
namespace steel {
|
||||||
|
|
||||||
|
// Parameter block for a GEMM kernel. The kernel that consumes it is not part
// of this header, so the field notes below are inferred from the names and
// should be confirmed against that kernel. Field order and types define the
// binary layout and must not change.
struct GEMMParams {
  const int M; // presumably the problem sizes M, N, K
  const int N;
  const int K;

  const int lda; // presumably the leading dimensions of A, B, C
  const int ldb;
  const int ldc;

  const int tiles_n; // presumably the tile-grid extent along N and M
  const int tiles_m;

  const int batch_stride_a; // presumably per-batch element strides
  const int batch_stride_b;
  const int batch_stride_c;

  const int swizzle_log; // presumably the shift passed to BlockSwizzle
  const int gemm_k_iterations_aligned; // presumably the main-loop trip count
};
|
||||||
|
|
||||||
|
// Parameter block read by the split-K GEMM kernel (gemm_splitk). Field order
// and types define the binary layout and must not change.
struct GEMMSpiltKParams {
  const int M; // upper bound used to clip tiles along the rows of C
  const int N; // upper bound used to clip tiles along the columns of C
  const int K; // total reduction length

  const int lda; // leading dimension handed to the A loader
  const int ldb; // leading dimension handed to the B loader
  const int ldc; // leading dimension used when storing C

  const int tiles_n; // threadgroups with tid.x >= tiles_n exit early
  const int tiles_m; // threadgroups with tid.y >= tiles_m exit early

  const int split_k_partitions; // number of K partitions
  const int split_k_partition_stride; // element offset in C between partitions
  const int split_k_partition_size; // K offset between partitions

  const int gemm_k_iterations_aligned; // trip count of the main loop
};
|
||||||
|
|
||||||
|
// Parameter block for a GEMM kernel with an extra input/output matrix and
// alpha / beta scalars. The kernel that consumes it is not part of this
// header, so the field notes below are inferred from the names and should be
// confirmed against that kernel. Field order and types define the binary
// layout and must not change.
struct GEMMAddMMParams {
  const int M; // presumably the problem sizes M, N, K
  const int N;
  const int K;

  const int lda; // presumably the leading dimensions of A, B, C, D
  const int ldb;
  const int ldc;
  const int ldd;

  const int tiles_n; // presumably the tile-grid extent along N and M
  const int tiles_m;

  const int batch_stride_a; // presumably per-batch element strides
  const int batch_stride_b;
  const int batch_stride_c;
  const int batch_stride_d;

  const int swizzle_log; // presumably the shift passed to BlockSwizzle
  const int gemm_k_iterations_aligned; // presumably the main-loop trip count

  const float alpha; // presumably the epilogue scale factors
  const float beta;

  const int fdc; // presumably the column stride of C
};
|
||||||
|
|
||||||
|
} // namespace steel
|
||||||
|
} // namespace mlx
|
BIN
candle-metal-kernels/src/gemm/steel_gemm.metallib
Normal file
BIN
candle-metal-kernels/src/gemm/steel_gemm.metallib
Normal file
Binary file not shown.
63
candle-metal-kernels/src/gemm/transforms.h
Normal file
63
candle-metal-kernels/src/gemm/transforms.h
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Transforms and Epilogues
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace mlx {
|
||||||
|
namespace steel {
|
||||||
|
|
||||||
|
template <typename OutT, typename InT>
|
||||||
|
struct TransformNone {
|
||||||
|
static METAL_FUNC OutT apply(InT x) {
|
||||||
|
return static_cast<OutT>(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
static METAL_FUNC OutT apply(InT x, OutT) {
|
||||||
|
return static_cast<OutT>(x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OutT, typename InT>
|
||||||
|
struct TransformAdd {
|
||||||
|
TransformAdd(const float, const float) {}
|
||||||
|
|
||||||
|
static METAL_FUNC OutT apply(InT x, OutT c) {
|
||||||
|
return static_cast<OutT>(x) + c;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OutT, typename InT>
|
||||||
|
struct TransformAxpby {
|
||||||
|
const float alpha;
|
||||||
|
const float beta;
|
||||||
|
|
||||||
|
TransformAxpby(const float alpha_, const float beta_)
|
||||||
|
: alpha(alpha_), beta(beta_) {}
|
||||||
|
|
||||||
|
METAL_FUNC OutT apply(InT x, OutT c) const {
|
||||||
|
return static_cast<OutT>(x * alpha + (beta * c));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Maps an element type to its accumulation type; this primary template
// selects float for every T.
template <typename T>
struct AccumHelper {
  using accum_type = float;
};
|
||||||
|
|
||||||
|
struct BlockSwizzle {
|
||||||
|
static METAL_FUNC int2
|
||||||
|
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
|
||||||
|
const int tid_x = (tid.x) >> swizzle_log;
|
||||||
|
const int tid_y =
|
||||||
|
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
|
||||||
|
return int2(tid_x, tid_y);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace steel
|
||||||
|
} // namespace mlx
|
276
candle-metal-kernels/src/gemm/utils.h
Normal file
276
candle-metal-kernels/src/gemm/utils.h
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_math>
|
||||||
|
#include "gemm/bf16.h"
|
||||||
|
#include "gemm/complex.h"
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Type limits utils
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
struct Limits {
|
||||||
|
static const constant U max = metal::numeric_limits<U>::max();
|
||||||
|
static const constant U min = metal::numeric_limits<U>::min();
|
||||||
|
static const constant U finite_max = metal::numeric_limits<U>::max();
|
||||||
|
static const constant U finite_min = metal::numeric_limits<U>::min();
|
||||||
|
};
|
||||||
|
|
||||||
|
// Specializes Limits for a type whose max / min come straight from
// metal::numeric_limits; finite_max / finite_min repeat the same values.
#define instantiate_default_limit(int_type) \
  template <> \
  struct Limits<int_type> { \
    static constexpr constant int_type max = \
        metal::numeric_limits<int_type>::max(); \
    static constexpr constant int_type min = \
        metal::numeric_limits<int_type>::min(); \
    static constexpr constant int_type finite_max = \
        metal::numeric_limits<int_type>::max(); \
    static constexpr constant int_type finite_min = \
        metal::numeric_limits<int_type>::min(); \
  };
|
||||||
|
|
||||||
|
instantiate_default_limit(uint8_t);
|
||||||
|
instantiate_default_limit(uint16_t);
|
||||||
|
instantiate_default_limit(uint32_t);
|
||||||
|
instantiate_default_limit(uint64_t);
|
||||||
|
instantiate_default_limit(int8_t);
|
||||||
|
instantiate_default_limit(int16_t);
|
||||||
|
instantiate_default_limit(int32_t);
|
||||||
|
instantiate_default_limit(int64_t);
|
||||||
|
|
||||||
|
// Specializes Limits<type> for a floating-point type:
//   max / min               -> +infinity / -infinity
//   finite_max / finite_min -> +numeric_limits::max() / -numeric_limits::max()
#define instantiate_float_limit(type)             \
  template <>                                     \
  struct Limits<type> {                           \
    static constexpr constant type max =          \
        metal::numeric_limits<type>::infinity();  \
    static constexpr constant type finite_max =   \
        metal::numeric_limits<type>::max();       \
    static constexpr constant type min =          \
        -metal::numeric_limits<type>::infinity(); \
    static constexpr constant type finite_min =   \
        -metal::numeric_limits<type>::max();      \
  };

instantiate_float_limit(float);
instantiate_float_limit(half);
instantiate_float_limit(bfloat16_t);
|
||||||
|
|
||||||
|
// Limits for bool: false is the minimum, true is the maximum.
// This specialization has no finite_min / finite_max members.
template <>
struct Limits<bool> {
  static constexpr constant bool min = false;
  static constexpr constant bool max = true;
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Indexing utils
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Converts a linear element index into a strided offset.
// Coordinates are peeled off from the last dimension to the first
// (elem % shape[d], then elem /= shape[d]) and each is weighted by strides[d].
// `shape` and `strides` are read from the device address space and must hold
// at least `ndim` entries. Returns 0 when ndim <= 0.
inline size_t elem_to_loc(
    uint elem,
    device const int* shape,
    device const size_t* strides,
    int ndim) {
  size_t offset = 0;
  int dim = ndim;
  while (dim-- > 0) {
    const uint coord = elem % shape[dim];
    offset += coord * strides[dim];
    elem /= shape[dim];
  }
  return offset;
}
|
||||||
|
|
||||||
|
// Same as the device-pointer overload, but `shape` and `strides` live in the
// constant address space: walks the dimensions from last to first, turning
// the linear index `elem` into per-dimension coordinates and summing
// coordinate * stride. Returns 0 when ndim <= 0.
inline size_t elem_to_loc(
    uint elem,
    constant const int* shape,
    constant const size_t* strides,
    int ndim) {
  size_t offset = 0;
  uint remaining = elem;
  for (int d = ndim - 1; d >= 0; --d) {
    const uint coord = remaining % shape[d];
    remaining /= shape[d];
    offset += coord * strides[d];
  }
  return offset;
}
|
||||||
|
|
||||||
|
// Computes the strided offsets of one element in two arrays (a and b) that
// share `shape` but have different strides.
//   elem.x -> coordinate along dimension NDIM - 1
//   elem.y -> coordinate along dimension NDIM - 2
//   elem.z -> linear index over the remaining leading dimensions
// Returns uint2(offset_in_a, offset_in_b); offsets are truncated to 32 bits.
// Indexes strides[NDIM - 2], so NDIM is expected to be at least 2.
template <int NDIM>
inline uint2 elem_to_loc_2_nd(
    uint3 elem,
    constant const int shape[NDIM],
    constant const size_t a_strides[NDIM],
    constant const size_t b_strides[NDIM]) {
  const uint a_base = static_cast<uint>(
      elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]);
  const uint b_base = static_cast<uint>(
      elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]);
  uint2 loc = {a_base, b_base};

  // Unravel elem.z over the leading dimensions, innermost first.
  uint leading = elem.z;
  for (int d = NDIM - 3; d >= 0; --d) {
    const uint coord = leading % shape[d];
    loc.x += coord * a_strides[d];
    loc.y += coord * b_strides[d];
    leading /= shape[d];
  }
  return loc;
}
|
||||||
|
|
||||||
|
// Strided offset of a 3D-indexed element for a compile-time rank NDIM.
//   elem.x -> coordinate along dimension NDIM - 1
//   elem.y -> coordinate along dimension NDIM - 2
//   elem.z -> linear index over the remaining leading dimensions
// Indexes strides[NDIM - 2], so NDIM is expected to be at least 2.
template <int NDIM>
inline size_t elem_to_loc_nd(
    uint3 elem,
    constant const int shape[NDIM],
    constant const size_t strides[NDIM]) {
  const size_t last_term = elem.x * strides[NDIM - 1];
  const size_t second_last_term = elem.y * strides[NDIM - 2];
  size_t offset = last_term + second_last_term;

  // Unravel elem.z over the leading dimensions, innermost first.
  uint leading = elem.z;
  int d = NDIM - 3;
  while (d >= 0) {
    offset += (leading % shape[d]) * strides[d];
    leading /= shape[d];
    --d;
  }
  return offset;
}
|
||||||
|
|
||||||
|
// 1D case: the offset is simply index * stride.
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
  const size_t offset = elem * stride;
  return offset;
}
|
||||||
|
|
||||||
|
// 2D case: elem.x pairs with strides[1] and elem.y with strides[0].
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
  const size_t x_term = elem.x * strides[1];
  const size_t y_term = elem.y * strides[0];
  return x_term + y_term;
}
|
||||||
|
|
||||||
|
// 3D case: elem.x pairs with strides[2], elem.y with strides[1] and elem.z
// with strides[0].
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
  const size_t x_term = elem.x * strides[2];
  const size_t y_term = elem.y * strides[1];
  const size_t z_term = elem.z * strides[0];
  return x_term + y_term + z_term;
}
|
||||||
|
|
||||||
|
// Runtime-rank counterpart of the templated elem_to_loc_nd, for arbitrary
// dimension counts.
//   elem.x -> coordinate along dimension ndim - 1
//   elem.y -> coordinate along dimension ndim - 2
//   elem.z -> linear index over the remaining leading dimensions
// Indexes strides[ndim - 2], so ndim is expected to be at least 2.
inline size_t elem_to_loc(
    uint3 elem,
    constant const int* shape,
    constant const size_t* strides,
    int ndim) {
  const size_t last_term = elem.x * strides[ndim - 1];
  const size_t second_last_term = elem.y * strides[ndim - 2];
  size_t offset = last_term + second_last_term;

  // Unravel elem.z over the leading dimensions, innermost first.
  uint leading = elem.z;
  for (int d = ndim - 3; d >= 0; --d) {
    const uint coord = leading % shape[d];
    offset += coord * strides[d];
    leading /= shape[d];
  }
  return offset;
}
|
||||||
|
|
||||||
|
// Runtime-rank version of elem_to_loc_2_nd: computes the strided offsets of
// one element in two arrays (a and b) that share `shape`.
//   elem.x -> coordinate along dimension ndim - 1
//   elem.y -> coordinate along dimension ndim - 2
//   elem.z -> linear index over the remaining leading dimensions
// Returns uint2(offset_in_a, offset_in_b); offsets are truncated to 32 bits.
// Indexes strides[ndim - 2], so ndim is expected to be at least 2.
inline uint2 elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    int ndim) {
  const uint a_base = static_cast<uint>(
      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]);
  const uint b_base = static_cast<uint>(
      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]);
  uint2 loc = {a_base, b_base};

  // Unravel elem.z over the leading dimensions, innermost first.
  uint leading = elem.z;
  int d = ndim - 3;
  while (d >= 0) {
    const uint coord = leading % shape[d];
    loc.x += coord * a_strides[d];
    loc.y += coord * b_strides[d];
    leading /= shape[d];
    --d;
  }
  return loc;
}
|
||||||
|
|
||||||
|
// Fixed-rank conversion of a linear element index into a strided offset,
// with `shape` and `strides` read from the device address space.
// This primary template is declaration-only; a definition must come from an
// explicit specialization for the requested NDIM.
template <int NDIM>
inline uint elem_to_loc_nd(
    uint elem,
    device const int* shape,
    device const size_t* strides);
|
||||||
|
|
||||||
|
// Rank 1: wrap the index by shape[0] and scale it by strides[0].
// The result is truncated to 32 bits by the uint return type.
template <>
inline uint elem_to_loc_nd<1>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  const uint coord = elem % shape[0];
  return coord * strides[0];
}
|
||||||
|
|
||||||
|
// Rank 2: split the linear index into (i0, i1) with i1 varying fastest and
// return i1 * strides[1] + i0 * strides[0], truncated to 32 bits.
template <>
inline uint elem_to_loc_nd<2>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  const uint i1 = elem % shape[1];
  elem /= shape[1];
  const uint i0 = elem % shape[0];
  return i1 * strides[1] + i0 * strides[0];
}
|
||||||
|
|
||||||
|
// Rank 3: split the linear index into (i0, i1, i2) with i2 varying fastest
// and return the stride-weighted sum, truncated to 32 bits.
template <>
inline uint elem_to_loc_nd<3>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  const uint i2 = elem % shape[2];
  elem /= shape[2];
  const uint i1 = elem % shape[1];
  elem /= shape[1];
  const uint i0 = elem % shape[0];
  return i2 * strides[2] + i1 * strides[1] + i0 * strides[0];
}
|
||||||
|
|
||||||
|
// Rank 4: split the linear index into (i0, i1, i2, i3) with i3 varying
// fastest and return the stride-weighted sum, truncated to 32 bits.
template <>
inline uint elem_to_loc_nd<4>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  const uint i3 = elem % shape[3];
  elem /= shape[3];
  const uint i2 = elem % shape[2];
  elem /= shape[2];
  const uint i1 = elem % shape[1];
  elem /= shape[1];
  const uint i0 = elem % shape[0];
  return i3 * strides[3] + i2 * strides[2] + i1 * strides[1] +
      i0 * strides[0];
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Calculation utils
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/**
 * Compute ceil((float)N/(float)M) using integer arithmetic only.
 *
 * Written as quotient plus a remainder check rather than (N + M - 1) / M so
 * that the intermediate sum cannot wrap around when N is close to the maximum
 * size_t value. M must be non-zero.
 */
inline size_t ceildiv(size_t N, size_t M) {
  const size_t quotient = N / M;
  const size_t has_remainder = (N % M != 0) ? 1 : 0;
  return quotient + has_remainder;
}
|
||||||
|
|
||||||
|
// log(1 + x) evaluated as x * log(1 + x) / ((1 + x) - 1); see
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
// Special cases:
//   1 + x equals Limits<float>::max -> Limits<float>::max is returned
//   1 + x rounds to exactly 1       -> x itself is returned
inline float log1p(float x) {
  const float one_plus_x = 1.0f + x;
  if (one_plus_x == Limits<float>::max) {
    return Limits<float>::max;
  }
  if (one_plus_x == 1.0f) {
    return x;
  }

  const float correction = metal::log(one_plus_x) / (one_plus_x - 1.0f);
  return x * correction;
}
|
||||||
|
|
||||||
|
// bfloat16 overload of log1p. The intermediate 1 + x is computed in float and
// the result is converted back to bfloat16_t.
// Special cases:
//   1 + x equals Limits<float>::max -> Limits<bfloat16_t>::max is returned
//   1 + x rounds to exactly 1       -> x itself is returned
inline bfloat16_t log1p(bfloat16_t x) {
  const float one_plus_x = 1.0f + static_cast<float>(x);
  if (one_plus_x == Limits<float>::max) {
    return Limits<bfloat16_t>::max;
  }
  if (one_plus_x == 1.0f) {
    return x;
  }

  const float correction = metal::log(one_plus_x) / (one_plus_x - 1.0f);
  return bfloat16_t(x * correction);
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// SIMD shuffle ops
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// simd_shuffle_down for uint64_t: the value is reinterpreted as a uint2,
// shuffled with metal::simd_shuffle_down, and reinterpreted back.
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  const uint2 shuffled = metal::simd_shuffle_down(halves, delta);
  return as_type<uint64_t>(shuffled);
}
|
||||||
|
|
||||||
|
// simd_shuffle_down for int64_t: the value is reinterpreted as a uint2,
// shuffled with metal::simd_shuffle_down, and reinterpreted back.
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  const uint2 shuffled = metal::simd_shuffle_down(halves, delta);
  return as_type<int64_t>(shuffled);
}
|
||||||
|
|
||||||
|
// simd_shuffle_down for bool: the flag is widened to uint32_t, shuffled, and
// the result is converted back to bool by the return type.
inline bool simd_shuffle_down(bool data, uint16_t delta) {
  const uint32_t widened = static_cast<uint32_t>(data);
  return simd_shuffle_down(widened, delta);
}
|
9
candle-metal-kernels/src/gemm/utils2.h
Normal file
9
candle-metal-kernels/src/gemm/utils2.h
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
#include "gemm/host.h"
|
||||||
|
|
||||||
|
// Declaration-specifier shorthand for compile-time constants: expands to
// `static constant constexpr const`, to be followed by the type and name.
#define STEEL_CONST static constant constexpr const
|
||||||
|
// Expands to `#pragma clang loop unroll(full)`; place it directly before a
// loop to request full unrolling from the compiler.
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
@ -1,6 +1,7 @@
|
|||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize,
|
||||||
|
NSUInteger,
|
||||||
};
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
@ -16,6 +17,7 @@ const CONV: &str = include_str!("conv.metal");
|
|||||||
const REDUCE: &str = include_str!("reduce.metal");
|
const REDUCE: &str = include_str!("reduce.metal");
|
||||||
const RANDOM: &str = include_str!("random.metal");
|
const RANDOM: &str = include_str!("random.metal");
|
||||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||||
|
const GEMM: &[u8] = include_bytes!("gemm/steel_gemm.metallib");
|
||||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||||
|
|
||||||
/// Most kernels apply similarly across the tensors
|
/// Most kernels apply similarly across the tensors
|
||||||
@ -122,6 +124,7 @@ pub enum Source {
|
|||||||
Cast,
|
Cast,
|
||||||
Reduce,
|
Reduce,
|
||||||
Mfa,
|
Mfa,
|
||||||
|
Gemm,
|
||||||
Conv,
|
Conv,
|
||||||
Random,
|
Random,
|
||||||
Quantized,
|
Quantized,
|
||||||
@ -248,6 +251,7 @@ impl Kernels {
|
|||||||
Source::Random => RANDOM,
|
Source::Random => RANDOM,
|
||||||
Source::Quantized => QUANTIZED,
|
Source::Quantized => QUANTIZED,
|
||||||
Source::Mfa => panic!("Invalid lib"),
|
Source::Mfa => panic!("Invalid lib"),
|
||||||
|
Source::Gemm => panic!("Invalid lib"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -271,6 +275,14 @@ impl Kernels {
|
|||||||
))
|
))
|
||||||
})?
|
})?
|
||||||
}
|
}
|
||||||
|
Source::Gemm => {
|
||||||
|
let source_data = GEMM;
|
||||||
|
device.new_library_with_data(source_data).map_err(|e| {
|
||||||
|
MetalKernelError::LoadLibraryError(format!(
|
||||||
|
"Candle metal requires macosx > 13.0 or higher, cannot load GEMM: {e}"
|
||||||
|
))
|
||||||
|
})?
|
||||||
|
}
|
||||||
source => {
|
source => {
|
||||||
let source_content = self.get_library_source(source);
|
let source_content = self.get_library_source(source);
|
||||||
device
|
device
|
||||||
@ -1230,6 +1242,34 @@ impl ConstantValues {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a `&'static str` with the same contents as `s`.
///
/// The string is made `'static` by leaking its allocation, but each distinct
/// string is leaked at most once: strings handed out earlier are remembered in
/// a process-wide intern table and reused. Callers that rebuild the same name
/// with `format!` on every invocation therefore no longer leak a fresh
/// allocation per call.
fn string_to_static_str(s: String) -> &'static str {
    use std::collections::HashSet;
    use std::sync::Mutex;

    // `None` until first use so the static can be const-initialised.
    static INTERNED: Mutex<Option<HashSet<&'static str>>> = Mutex::new(None);

    // The table is only ever inserted into, so it is still consistent if
    // another thread panicked while holding the lock; recover from poisoning
    // instead of propagating the panic.
    let mut guard = INTERNED
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let table = guard.get_or_insert_with(HashSet::new);

    if let Some(existing) = table.get(s.as_str()).copied() {
        return existing;
    }

    let leaked: &'static str = Box::leak(s.into_boxed_str());
    table.insert(leaked);
    leaked
}
|
||||||
|
|
||||||
|
use core::ffi::c_int;
|
||||||
|
|
||||||
|
/// Parameter block for the steel GEMM kernels.
///
/// `call_gemm` fills one of these and copies its raw bytes into a Metal buffer,
/// so the `#[repr(C)]` field order and the `c_int` field types are part of the
/// contract. Presumably the layout has to mirror the kernel-side parameter
/// struct; verify against the Metal sources before changing it.
#[derive(Debug)]
#[repr(C)]
struct GEMMParams {
    // Problem size.
    m: c_int,
    n: c_int,
    k: c_int,

    // Leading dimensions of the lhs, rhs and output matrices.
    lda: c_int,
    ldb: c_int,
    ldc: c_int,

    // Number of output tiles along n and along m.
    tiles_n: c_int,
    tiles_m: c_int,

    // Element strides between consecutive matrices of a batch.
    batch_stride_a: c_int,
    batch_stride_b: c_int,
    batch_stride_c: c_int,

    // log2 of the threadgroup swizzle factor.
    swizzle_log: c_int,
    // Number of whole bk-sized steps along k (k / bk).
    gemm_k_iterations_aligned: c_int,
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_gemm(
|
pub fn call_gemm(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
@ -1251,10 +1291,10 @@ pub fn call_gemm(
|
|||||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
|
let (a_trans, lda) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||||
false
|
(false, k as c_int)
|
||||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
true
|
(true, n as c_int)
|
||||||
} else {
|
} else {
|
||||||
return Err(MetalKernelError::MatMulNonContiguous {
|
return Err(MetalKernelError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
@ -1262,10 +1302,10 @@ pub fn call_gemm(
|
|||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?;
|
})?;
|
||||||
};
|
};
|
||||||
let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
|
let (b_trans, ldb) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||||
false
|
(false, n as c_int)
|
||||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
true
|
(true, k as c_int)
|
||||||
} else {
|
} else {
|
||||||
return Err(MetalKernelError::MatMulNonContiguous {
|
return Err(MetalKernelError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
@ -1273,119 +1313,195 @@ pub fn call_gemm(
|
|||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?;
|
})?;
|
||||||
};
|
};
|
||||||
let d_trans = false;
|
// let d_trans = false;
|
||||||
let alpha = 1.0f32;
|
// let alpha = 1.0f32;
|
||||||
let beta = 0.0f32;
|
// let beta = 0.0f32;
|
||||||
let batched = b > 1;
|
// let batched = b > 1;
|
||||||
let fused_activation = false;
|
// let fused_activation = false;
|
||||||
let fused_bias = false;
|
// let fused_bias = false;
|
||||||
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
// let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
||||||
let m_simd = 8;
|
// let m_simd = 8;
|
||||||
let n_simd = 8;
|
// let n_simd = 8;
|
||||||
let k_simd = 64;
|
// let k_simd = 64;
|
||||||
let m_splits = 1;
|
// let m_splits = 1;
|
||||||
let n_splits = 1;
|
// let n_splits = 1;
|
||||||
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
// (m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||||
} else {
|
// } else {
|
||||||
let m_simd = 40;
|
// let m_simd = 40;
|
||||||
let n_simd = 40;
|
// let n_simd = 40;
|
||||||
let k_simd = 32;
|
// let k_simd = 32;
|
||||||
let m_splits = 1;
|
// let m_splits = 1;
|
||||||
let n_splits = 1;
|
// let n_splits = 1;
|
||||||
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
// (m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||||
};
|
// };
|
||||||
let constants = Some(ConstantValues::new(vec![
|
// let constants = Some(ConstantValues::new(vec![
|
||||||
(0, Value::USize(m)),
|
// (0, Value::USize(m)),
|
||||||
(1, Value::USize(n)),
|
// (1, Value::USize(n)),
|
||||||
(2, Value::USize(k)),
|
// (2, Value::USize(k)),
|
||||||
(10, Value::Bool(a_trans)),
|
// (10, Value::Bool(a_trans)),
|
||||||
(11, Value::Bool(b_trans)),
|
// (11, Value::Bool(b_trans)),
|
||||||
(13, Value::Bool(d_trans)),
|
// (13, Value::Bool(d_trans)),
|
||||||
(20, Value::F32(alpha)),
|
// (20, Value::F32(alpha)),
|
||||||
(21, Value::F32(beta)),
|
// (21, Value::F32(beta)),
|
||||||
(100, Value::Bool(batched)),
|
// (100, Value::Bool(batched)),
|
||||||
(101, Value::Bool(fused_activation)),
|
// (101, Value::Bool(fused_activation)),
|
||||||
// Garbage
|
// // Garbage
|
||||||
(102, Value::Bool(false)),
|
// (102, Value::Bool(false)),
|
||||||
(103, Value::Bool(false)),
|
// (103, Value::Bool(false)),
|
||||||
(113, Value::Bool(false)),
|
// (113, Value::Bool(false)),
|
||||||
(50_000, Value::Bool(false)),
|
// (50_000, Value::Bool(false)),
|
||||||
// End garbage
|
// // End garbage
|
||||||
(200, Value::U16(m_simd)),
|
// (200, Value::U16(m_simd)),
|
||||||
(201, Value::U16(n_simd)),
|
// (201, Value::U16(n_simd)),
|
||||||
(202, Value::U16(k_simd)),
|
// (202, Value::U16(k_simd)),
|
||||||
(210, Value::U16(m_splits)),
|
// (210, Value::U16(m_splits)),
|
||||||
(211, Value::U16(n_splits)),
|
// (211, Value::U16(n_splits)),
|
||||||
(50_001, Value::Bool(fused_bias)),
|
// (50_001, Value::Bool(fused_bias)),
|
||||||
]));
|
// ]));
|
||||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
|
let a_trans_name = if a_trans { "t" } else { "n" };
|
||||||
let m_group = m_simd * m_splits;
|
let b_trans_name = if b_trans { "t" } else { "n" };
|
||||||
let n_group = n_simd * n_splits;
|
let (iname, oname) = match name {
|
||||||
|
"sgemm" => ("float32", "float32"),
|
||||||
let a_block_length = m_group * k_simd;
|
"hgemm" => ("float16", "float16"),
|
||||||
let b_block_length = k_simd * n_group;
|
"bgemm" => ("bfloat16", "bfloat16"),
|
||||||
|
|
||||||
let mut block_elements = a_block_length + b_block_length;
|
|
||||||
if (m % 8 != 0) && (n % 8 != 0) {
|
|
||||||
let c_block_length = m_group * n_group;
|
|
||||||
block_elements = std::cmp::max(c_block_length, block_elements)
|
|
||||||
}
|
|
||||||
if fused_bias {
|
|
||||||
if d_trans {
|
|
||||||
block_elements = std::cmp::max(block_elements, m_group);
|
|
||||||
} else {
|
|
||||||
block_elements = std::cmp::max(block_elements, n_group);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let bytes = match name {
|
|
||||||
"sgemm" => 4,
|
|
||||||
"hgemm" => 2,
|
|
||||||
other => {
|
other => {
|
||||||
return Err(MetalKernelError::LoadLibraryError(format!(
|
return Err(MetalKernelError::LoadLibraryError(format!(
|
||||||
"{other} is not a valid kernel for gemm"
|
"{other} is not a valid kernel for gemm"
|
||||||
)));
|
)))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let block_bytes = block_elements * bytes;
|
let mut bm = 32;
|
||||||
|
let mut bn = 32;
|
||||||
|
let mut bk = 16;
|
||||||
|
let wm = 2;
|
||||||
|
let wn = 2;
|
||||||
|
if b * m * n >= 1 << 20 {
|
||||||
|
if !a_trans && b_trans {
|
||||||
|
bm = 64;
|
||||||
|
bn = if oname == "float32" { 64 } else { 32 };
|
||||||
|
bk = if oname == "float32" { 16 } else { 32 };
|
||||||
|
} else {
|
||||||
|
bm = 64;
|
||||||
|
bn = 64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mnaligned = if m % bm == 0 && n % bn == 0 {
|
||||||
|
"taligned"
|
||||||
|
} else {
|
||||||
|
"naligned"
|
||||||
|
};
|
||||||
|
let kaligned = if k % bk == 0 { "taligned" } else { "naligned" };
|
||||||
|
// let bytes = match &name[..] {
|
||||||
|
// "sgemm" => 4,
|
||||||
|
// "hgemm" => 2,
|
||||||
|
// other => {
|
||||||
|
// return Err(MetalKernelError::LoadLibraryError(format!(
|
||||||
|
// "{other} is not a valid kernel for gemm"
|
||||||
|
// )));
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
let name = format!("steel_gemm_{a_trans_name}{b_trans_name}_{iname}_{oname}_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}_MN_{mnaligned}_K_{kaligned}");
|
||||||
|
let name = string_to_static_str(name);
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::Gemm, name)?;
|
||||||
|
// let m_group = m_simd * m_splits;
|
||||||
|
// let n_group = n_simd * n_splits;
|
||||||
|
|
||||||
|
// let a_block_length = m_group * k_simd;
|
||||||
|
// let b_block_length = k_simd * n_group;
|
||||||
|
|
||||||
|
// let mut block_elements = a_block_length + b_block_length;
|
||||||
|
// if (m % 8 != 0) && (n % 8 != 0) {
|
||||||
|
// let c_block_length = m_group * n_group;
|
||||||
|
// block_elements = std::cmp::max(c_block_length, block_elements)
|
||||||
|
// }
|
||||||
|
// if fused_bias {
|
||||||
|
// if d_trans {
|
||||||
|
// block_elements = std::cmp::max(block_elements, m_group);
|
||||||
|
// } else {
|
||||||
|
// block_elements = std::cmp::max(block_elements, n_group);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// let block_bytes = block_elements * bytes;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
// encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||||
|
|
||||||
|
let batch_stride_a: i32 = if lhs_stride.len() > 2 {
|
||||||
|
lhs_stride[lhs_stride.len() - 3] as i32
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
let batch_stride_b: i32 = if rhs_stride.len() > 2 {
|
||||||
|
rhs_stride[rhs_stride.len() - 3] as i32
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
let batch_stride_c = (m * n) as i32;
|
||||||
|
|
||||||
|
let swizzle_log = 0;
|
||||||
|
let tiles_n = ((n + bn - 1) / bn) as c_int;
|
||||||
|
let tiles_m = ((m + bm - 1) / bm) as c_int;
|
||||||
|
|
||||||
|
let params = GEMMParams {
|
||||||
|
m: m as c_int,
|
||||||
|
n: n as c_int,
|
||||||
|
k: k as c_int,
|
||||||
|
lda,
|
||||||
|
ldb,
|
||||||
|
ldc: n as c_int,
|
||||||
|
tiles_m,
|
||||||
|
tiles_n,
|
||||||
|
batch_stride_a,
|
||||||
|
batch_stride_b,
|
||||||
|
batch_stride_c,
|
||||||
|
swizzle_log,
|
||||||
|
gemm_k_iterations_aligned: (k / bk) as c_int,
|
||||||
|
};
|
||||||
|
let params_buffer = device.new_buffer_with_data(
|
||||||
|
¶ms as *const GEMMParams as *const c_void,
|
||||||
|
core::mem::size_of::<GEMMParams>() as u64,
|
||||||
|
MTLResourceOptions::StorageModeShared,
|
||||||
|
);
|
||||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||||
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
||||||
encoder.set_buffer(2, Some(output), 0);
|
encoder.set_buffer(2, Some(output), 0);
|
||||||
|
encoder.set_buffer(3, Some(¶ms_buffer), 0);
|
||||||
// TODO Tensor D
|
// TODO Tensor D
|
||||||
|
|
||||||
let grid_z = b;
|
let grid_z = b;
|
||||||
if batched {
|
// if batched {
|
||||||
let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
|
// let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
|
||||||
let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
|
// let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
|
||||||
let byte_stride_c = m * n * bytes as usize;
|
// let byte_stride_c = m * n * bytes as usize;
|
||||||
// TODO byte_stride_d
|
// // TODO byte_stride_d
|
||||||
let byte_stride_d = 0;
|
// let byte_stride_d = 0;
|
||||||
|
|
||||||
let buffer: Vec<u64> = vec![
|
// let buffer: Vec<u64> = vec![
|
||||||
byte_stride_a as _,
|
// byte_stride_a as _,
|
||||||
byte_stride_b as _,
|
// byte_stride_b as _,
|
||||||
byte_stride_c as _,
|
// byte_stride_c as _,
|
||||||
byte_stride_d as _,
|
// byte_stride_d as _,
|
||||||
];
|
// ];
|
||||||
encoder.set_bytes(
|
// // encoder.set_bytes(
|
||||||
10,
|
// // 10,
|
||||||
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
|
// // (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
|
||||||
buffer.as_ptr() as *const NSUInteger as *const c_void,
|
// // buffer.as_ptr() as *const NSUInteger as *const c_void,
|
||||||
);
|
// // );
|
||||||
}
|
// }
|
||||||
|
let tile = 1 << swizzle_log;
|
||||||
|
let tm = (tiles_m + tile - 1) / tile;
|
||||||
|
let tn = tiles_n * tile;
|
||||||
|
|
||||||
let grid_size = MTLSize {
|
let grid_size = MTLSize {
|
||||||
width: divide(n, n_group.into()),
|
width: tn as u64,
|
||||||
height: divide(m, m_group.into()),
|
height: tm as u64,
|
||||||
depth: grid_z as NSUInteger,
|
depth: grid_z as NSUInteger,
|
||||||
};
|
};
|
||||||
let group_size = MTLSize {
|
let group_size = MTLSize {
|
||||||
width: 32 * (m_splits as u64) * (n_splits as u64),
|
width: 32,
|
||||||
height: 1,
|
height: wn,
|
||||||
depth: 1,
|
depth: wm,
|
||||||
};
|
};
|
||||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||||
|
Reference in New Issue
Block a user