Update the flash attn kernels. (#2333)

This commit is contained in:
Laurent Mazare
2024-07-15 20:37:36 +02:00
committed by GitHub
parent d74fbed334
commit 30cdd769f9
51 changed files with 2279 additions and 904 deletions

View File

@ -14,6 +14,7 @@
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
@ -25,6 +26,56 @@
} \
}()
// DROPOUT_SWITCH(COND, CONST_NAME, callable)
//
// Default build: just another name for BOOL_SWITCH.
// With FLASHATTENTION_DISABLE_DROPOUT defined: COND does not appear in the
// expansion (so it is never evaluated), CONST_NAME is declared as a
// constexpr static bool fixed to false, and the trailing callable is invoked
// with its result returned from the wrapping lambda.
#ifndef FLASHATTENTION_DISABLE_DROPOUT
#define DROPOUT_SWITCH BOOL_SWITCH
#else
#define DROPOUT_SWITCH(COND, CONST_NAME, ...)     \
    [&] {                                         \
        static constexpr bool CONST_NAME = false; \
        return __VA_ARGS__();                     \
    }()
#endif
// ALIBI_SWITCH(COND, CONST_NAME, callable)
//
// Default build: just another name for BOOL_SWITCH.
// With FLASHATTENTION_DISABLE_ALIBI defined: COND does not appear in the
// expansion (so it is never evaluated), CONST_NAME is declared as a
// constexpr static bool fixed to false, and the trailing callable is invoked
// with its result returned from the wrapping lambda.
#ifndef FLASHATTENTION_DISABLE_ALIBI
#define ALIBI_SWITCH BOOL_SWITCH
#else
#define ALIBI_SWITCH(COND, CONST_NAME, ...)       \
    [&] {                                         \
        static constexpr bool CONST_NAME = false; \
        return __VA_ARGS__();                     \
    }()
#endif
// EVENK_SWITCH(COND, CONST_NAME, callable)
//
// Default build: just another name for BOOL_SWITCH.
// With FLASHATTENTION_DISABLE_UNEVEN_K defined: COND does not appear in the
// expansion (so it is never evaluated), CONST_NAME is declared as a
// constexpr static bool fixed to true (note: true, unlike the other
// feature-disable switches in this file), and the trailing callable is
// invoked with its result returned from the wrapping lambda.
#ifndef FLASHATTENTION_DISABLE_UNEVEN_K
#define EVENK_SWITCH BOOL_SWITCH
#else
#define EVENK_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                        \
        static constexpr bool CONST_NAME = true; \
        return __VA_ARGS__();                    \
    }()
#endif
// SOFTCAP_SWITCH(COND, CONST_NAME, callable)
//
// Default build: just another name for BOOL_SWITCH.
// With FLASHATTENTION_DISABLE_SOFTCAP defined: COND does not appear in the
// expansion (so it is never evaluated), CONST_NAME is declared as a
// constexpr static bool fixed to false, and the trailing callable is invoked
// with its result returned from the wrapping lambda.
#ifndef FLASHATTENTION_DISABLE_SOFTCAP
#define SOFTCAP_SWITCH BOOL_SWITCH
#else
#define SOFTCAP_SWITCH(COND, CONST_NAME, ...)     \
    [&] {                                         \
        static constexpr bool CONST_NAME = false; \
        return __VA_ARGS__();                     \
    }()
#endif
// LOCAL_SWITCH(COND, CONST_NAME, callable)
//
// Default build: just another name for BOOL_SWITCH.
// With FLASHATTENTION_DISABLE_LOCAL defined: COND does not appear in the
// expansion (so it is never evaluated), CONST_NAME is declared as a
// constexpr static bool fixed to false, and the trailing callable is invoked
// with its result returned from the wrapping lambda.
#ifndef FLASHATTENTION_DISABLE_LOCAL
#define LOCAL_SWITCH BOOL_SWITCH
#else
#define LOCAL_SWITCH(COND, CONST_NAME, ...)       \
    [&] {                                         \
        static constexpr bool CONST_NAME = false; \
        return __VA_ARGS__();                     \
    }()
#endif
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
@ -36,7 +87,7 @@
} \
}()
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
#define HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 32) { \
constexpr static int kHeadDim = 32; \