mirror of
https://github.com/huggingface/candle.git
synced 2025-06-23 04:46:15 +00:00
Update the flash attn kernels. (#2333)
This commit is contained in:
@ -14,6 +14,7 @@
|
||||
/// some_function<BoolConst>(...);
|
||||
/// });
|
||||
/// ```
|
||||
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
@ -25,6 +26,56 @@
|
||||
} \
|
||||
}()
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_DROPOUT
|
||||
#define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
constexpr static bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
}()
|
||||
#else
|
||||
#define DROPOUT_SWITCH BOOL_SWITCH
|
||||
#endif
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_ALIBI
|
||||
#define ALIBI_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
constexpr static bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
}()
|
||||
#else
|
||||
#define ALIBI_SWITCH BOOL_SWITCH
|
||||
#endif
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
#define EVENK_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
constexpr static bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
}()
|
||||
#else
|
||||
#define EVENK_SWITCH BOOL_SWITCH
|
||||
#endif
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
|
||||
#define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
constexpr static bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
}()
|
||||
#else
|
||||
#define SOFTCAP_SWITCH BOOL_SWITCH
|
||||
#endif
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_LOCAL
|
||||
#define LOCAL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
constexpr static bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
}()
|
||||
#else
|
||||
#define LOCAL_SWITCH BOOL_SWITCH
|
||||
#endif
|
||||
|
||||
#define FP16_SWITCH(COND, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
@ -36,7 +87,7 @@
|
||||
} \
|
||||
}()
|
||||
|
||||
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
|
||||
#define HEADDIM_SWITCH(HEADDIM, ...) \
|
||||
[&] { \
|
||||
if (HEADDIM <= 32) { \
|
||||
constexpr static int kHeadDim = 32; \
|
||||
|
Reference in New Issue
Block a user