mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
chore: update flash attention kernels (#1518)
* chore: update flash attention kernels * fmt * remove unused kernels * force f32 * correct stride
This commit is contained in:
@ -4,15 +4,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// #include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "static_switch.h"
|
||||
#include "flash.h"
|
||||
#include "flash_fwd_kernel.h"
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
|
||||
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
|
||||
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
|
||||
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
|
||||
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
|
||||
@ -26,35 +25,39 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid(num_m_block, params.b, params.h);
|
||||
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
|
||||
// for cu_seqlens_q as well.
|
||||
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
|
||||
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
const bool return_softmax = params.p_ptr != nullptr;
|
||||
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
|
||||
// if (smem_size >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
// }
|
||||
int ctas_per_sm;
|
||||
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
// C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||
BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
|
||||
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
|
||||
// int ctas_per_sm;
|
||||
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 32;
|
||||
constexpr static int Headdim = 32;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
@ -64,7 +67,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 64;
|
||||
constexpr static int Headdim = 64;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
@ -86,7 +89,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 96;
|
||||
constexpr static int Headdim = 96;
|
||||
// auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
@ -112,7 +115,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 128;
|
||||
constexpr static int Headdim = 128;
|
||||
// auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
@ -149,7 +152,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 160;
|
||||
constexpr static int Headdim = 160;
|
||||
// auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
@ -179,7 +182,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 192;
|
||||
constexpr static int Headdim = 192;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
@ -198,7 +201,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 224;
|
||||
constexpr static int Headdim = 224;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
@ -224,7 +227,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 256;
|
||||
constexpr static int Headdim = 256;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_sm, max_smem_per_block;
|
||||
|
Reference in New Issue
Block a user