Again set a few extra params in flash-attn. (#245)

* Again set a few extra params.

* Use the appropriate kernel sizes.

* Add all the kernel sizes.

* Parallel compiling.

* Reduce the amount of parallelism.

* Add the missing kernel.

* Fix a typo.

* Remove bf16 support for now.
This commit is contained in:
Laurent Mazare
2023-07-26 14:16:37 +01:00
committed by GitHub
parent fa2b64d678
commit 2ce5f12513
21 changed files with 476 additions and 116 deletions

View File

@ -0,0 +1,109 @@
#include "flash_fwd_launch_template.h"
// TODO: Switch back to handling bf16.
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
FWD_HEADDIM_SWITCH(params.d, [&] {
run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
});
}
// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// FP16_SWITCH(!params.is_bf16, [&] {
// FWD_HEADDIM_SWITCH(params.d, [&] {
// run_mha_fwd_<elem_type, kHeadDim>(params, stream);
// });
// });
// }
extern "C" void run_mha(
void *q_ptr,
void *k_ptr,
void *v_ptr,
void *o_ptr,
void *softmax_lse_ptr,
uint32_t q_batch_stride,
uint32_t k_batch_stride,
uint32_t v_batch_stride,
uint32_t o_batch_stride,
uint32_t q_row_stride,
uint32_t k_row_stride,
uint32_t v_row_stride,
uint32_t o_row_stride,
uint32_t q_head_stride,
uint32_t k_head_stride,
uint32_t v_head_stride,
uint32_t o_head_stride,
uint32_t b,
uint32_t h,
uint32_t h_k,
uint32_t d,
uint32_t d_rounded,
float softmax_scale,
uint32_t seqlen_q,
uint32_t seqlen_k,
uint32_t seqlen_q_rounded,
uint32_t seqlen_k_rounded,
int is_causal
) {
Flash_fwd_params params;
// Reset the parameters
memset(&params, 0, sizeof(params));
// Set the pointers and strides.
params.q_ptr = q_ptr;
params.k_ptr = k_ptr;
params.v_ptr = v_ptr;
params.o_ptr = o_ptr;
params.softmax_lse_ptr = softmax_lse_ptr;
// All stride are in elements, not bytes.
params.q_batch_stride = q_batch_stride;
params.k_batch_stride = k_batch_stride;
params.v_batch_stride = v_batch_stride;
params.o_batch_stride = o_batch_stride;
params.q_row_stride = q_row_stride;
params.k_row_stride = k_row_stride;
params.v_row_stride = v_row_stride;
params.o_row_stride = o_row_stride;
params.q_head_stride = q_head_stride;
params.k_head_stride = k_head_stride;
params.v_head_stride = v_head_stride;
params.o_head_stride = o_head_stride;
// Set the dimensions.
params.b = b;
params.h = h;
params.h_k = h_k;
params.h_h_k_ratio = h / h_k;
params.seqlen_q = seqlen_q;
params.seqlen_k = seqlen_k;
params.seqlen_q_rounded = seqlen_q_rounded;
params.seqlen_k_rounded = seqlen_k_rounded;
params.d = d;
params.d_rounded = d_rounded;
params.is_causal = is_causal;
// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
params.p_dropout = 1.; // probability to keep
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
params.is_bf16 = 0;
params.cu_seqlens_q = nullptr;
params.cu_seqlens_k = nullptr;
params.p_ptr = nullptr;
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);
}

View File

@ -0,0 +1,19 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// if (params.p_dropout == 1.f) {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
// }
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
}

View File

@ -0,0 +1,32 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// if (params.p_dropout == 1.f) {
// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
// // 1st ones are good for H100, A100
// // 2nd one is good for A6000 bc we get slightly better occupancy
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
// // 1st one is good for H100, A100, A6000
// }
// }
template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
}

View File

@ -0,0 +1,17 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
}

View File

@ -0,0 +1,27 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
// // For A100, H100, 1st is fastest.
// });
// }
template<>
void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
}

View File

@ -0,0 +1,16 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
}

View File

@ -0,0 +1,27 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// // This one is slightly faster for causal?
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
// });
// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
// // For A6000, 1st is faster when causal, 3rd is faster when not causal
// }
template<>
void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
}

View File

@ -0,0 +1,9 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
}

View File

@ -0,0 +1,9 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
}

View File

@ -0,0 +1,9 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
}

View File

@ -0,0 +1,9 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
}

View File

@ -0,0 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
}

View File

@ -20,94 +20,4 @@
template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
}
extern "C" void run_mha(
void *q_ptr,
void *k_ptr,
void *v_ptr,
void *o_ptr,
void *softmax_lse_ptr,
uint32_t q_batch_stride,
uint32_t k_batch_stride,
uint32_t v_batch_stride,
uint32_t o_batch_stride,
uint32_t q_row_stride,
uint32_t k_row_stride,
uint32_t v_row_stride,
uint32_t o_row_stride,
uint32_t q_head_stride,
uint32_t k_head_stride,
uint32_t v_head_stride,
uint32_t o_head_stride,
uint32_t b,
uint32_t h,
uint32_t h_k,
uint32_t d,
uint32_t d_rounded,
float softmax_scale,
uint32_t seqlen_q,
uint32_t seqlen_k,
uint32_t seqlen_q_rounded,
uint32_t seqlen_k_rounded,
int is_causal
) {
Flash_fwd_params params;
// Reset the parameters
memset(&params, 0, sizeof(params));
// Set the pointers and strides.
params.q_ptr = q_ptr;
params.k_ptr = k_ptr;
params.v_ptr = v_ptr;
params.o_ptr = o_ptr;
params.softmax_lse_ptr = softmax_lse_ptr;
// All stride are in elements, not bytes.
params.q_batch_stride = q_batch_stride;
params.k_batch_stride = k_batch_stride;
params.v_batch_stride = v_batch_stride;
params.o_batch_stride = o_batch_stride;
params.q_row_stride = q_row_stride;
params.k_row_stride = k_row_stride;
params.v_row_stride = v_row_stride;
params.o_row_stride = o_row_stride;
params.q_head_stride = q_head_stride;
params.k_head_stride = k_head_stride;
params.v_head_stride = v_head_stride;
params.o_head_stride = o_head_stride;
// Set the dimensions.
params.b = b;
params.h = h;
params.h_k = h_k;
params.h_h_k_ratio = h / h_k;
params.seqlen_q = seqlen_q;
params.seqlen_k = seqlen_k;
params.seqlen_q_rounded = seqlen_q_rounded;
params.seqlen_k_rounded = seqlen_k_rounded;
params.d = d;
params.d_rounded = d_rounded;
params.is_causal = is_causal;
// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
params.p_dropout = 1.; // probability to keep
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd_<cutlass::half_t, 32>(params, stream);
}
}

View File

@ -0,0 +1,19 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// if (params.p_dropout == 1.f) {
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
// }
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
}

View File

@ -0,0 +1,26 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// if (params.p_dropout == 1.f) {
// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// // Using block size (64 x 256) is 27% slower for seqlen=2k
// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
// }
// }
template<>
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
}

View File

@ -0,0 +1,17 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
}

View File

@ -0,0 +1,23 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
// // This 3rd one is good for H100, and A100, A6000
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
// // These two are always slower
// // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
// });
// }
template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
}