Mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 19:47:12 +00:00)
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
@@ -3,11 +3,11 @@
******************************************************************************/

#pragma once

// #include <ATen/cuda/CUDAContext.h>
// #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include "error.h"
#include "static_switch.h"
#include "hardware_info.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
@@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
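Note on the change above: the dropout and return-softmax template flags are now gated on !Is_softcap, so no dropout or softmax-returning kernel variants are instantiated when soft-capping is enabled. For context, logit soft-capping bounds the raw attention scores before the softmax; below is a minimal host-side sketch of the usual formulation (the real kernel fuses this step inside flash_fwd_kernel, and apply_softcap is only an illustrative name, not an identifier from this file).

#include <cmath>
#include <vector>

// Illustration only: logit soft-capping as commonly defined, applied to raw
// attention scores before the softmax. A softcap of 0 means the feature is off.
std::vector<float> apply_softcap(std::vector<float> scores, float softcap) {
    if (softcap == 0.f) return scores;           // soft-capping disabled
    for (float &s : scores) {
        s = softcap * std::tanh(s / softcap);    // squashes every score into (-softcap, +softcap)
    }
    return scores;
}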
@@ -205,7 +205,8 @@ inline bool cuda_is_sm8x() {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
bool is_sm8x = cuda_is_sm8x();
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
if (is_sm8x) {
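This launcher (and the two that follow) replaces the cuda_is_sm8x() helper with a query through get_compute_capability(get_current_device()) from the newly included hardware_info.h, with the same meaning: is_sm8x is true on sm86/sm89 but not sm80. The sketch below shows how such helpers can be written against the plain CUDA runtime API; these are assumed stand-ins for illustration, not the actual contents of hardware_info.h.

#include <tuple>
#include <cuda_runtime.h>

// Assumed stand-ins for the helpers referenced above (illustration only).
inline int get_current_device() {
    int device = 0;
    cudaGetDevice(&device);                       // index of the currently active device
    return device;
}

inline std::tuple<int, int> get_compute_capability(int device) {
    int cc_major = 0, cc_minor = 0;
    cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, device);
    return {cc_major, cc_minor};                  // e.g. {8, 6} on an sm86 part
}

// Used exactly as in the launchers: sm86 / sm89 are "major 8, minor > 0".
// auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
// bool is_sm8x = cc_major == 8 && cc_minor > 0;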
@@ -228,7 +229,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
bool is_sm8x = cuda_is_sm8x();
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if constexpr(!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -262,7 +264,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
bool is_sm8x = cuda_is_sm8x();
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// For A100, H100, 128 x 32 is the fastest.
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
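The DROPOUT_SWITCH used by every launcher comes from static_switch.h. The usual pattern behind these *_SWITCH macros, sketched below as an assumption rather than the verbatim file contents, turns a runtime condition into a constexpr bool inside a lambda so that each branch instantiates its own kernel template; that is also why gating flags such as Is_dropout && !Is_softcap in the dispatch above directly reduces the number of kernels that get compiled.

// Assumed BOOL_SWITCH-style macro (illustrative): the runtime predicate picks a
// branch, and within each branch the condition is available as a compile-time
// constant that can be passed as a template argument.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

// Usage in the same spirit as the launchers above:
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//     run_flash_fwd<Kernel_traits, Is_dropout, Is_causal>(params, stream);
// });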