// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include <cstring>

#include "flash_fwd_launch_template.h"

// Earlier, inlined version of the specialization below, kept for reference
// (kernel-traits template arguments were elided from these comment lines).
// template<>
// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::half_t;
//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//         run_flash_fwd, Is_dropout>(params, stream);
//         // For dropout there might be a lot of register spilling?
//         // These two are very slow due to register spilling
//         // run_flash_fwd>(params, stream);
//         // run_flash_fwd>(params, stream);
//         // This one is slightly slower
//         // run_flash_fwd>(params, stream);
//     });
// }

// fp16 / head-dim-32 specialization of the forward dispatcher: delegates to the
// run_mha_fwd_hdim32 launcher instantiated for cutlass::half_t. The template
// arguments must be spelled out because neither the element type nor the head
// dimension can be deduced from the (Flash_fwd_params&, cudaStream_t) arguments.
template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
}

// C-linkage entry point: builds a Flash_fwd_params from raw pointers and sizes,
// then runs the fp16, head-dim-32 forward pass on the default stream (0).
//
// Parameters:
//   q_ptr, k_ptr, v_ptr   - input tensors (presumably device memory holding
//                           cutlass::half_t values — TODO confirm with caller).
//   o_ptr                 - output tensor.
//   *_batch_stride, *_row_stride, *_head_stride
//                         - strides of Q/K/V, in elements, not bytes.
//   b                     - batch size.
//   h, h_k                - number of query heads and key/value heads; h_k must
//                           be non-zero because h / h_k is computed.
//   d, d_rounded          - head dimension and its rounded value.
//   softmax_scale         - scale stored as params.scale_softmax (its base-2
//                           variant is derived below).
//   seqlen_q, seqlen_k    - query / key sequence lengths.
//   seqlen_q_rounded, seqlen_k_rounded
//                         - rounded sequence lengths.
//   is_causal             - stored in params.is_causal (non-zero -> true).
//
// The function returns void and performs no CUDA error checking; callers should
// query cudaGetLastError() and/or synchronize to observe launch/execution errors.
extern "C" void run_mha(
    void *q_ptr,
    void *k_ptr,
    void *v_ptr,
    void *o_ptr,

    uint32_t q_batch_stride,
    uint32_t k_batch_stride,
    uint32_t v_batch_stride,

    uint32_t q_row_stride,
    uint32_t k_row_stride,
    uint32_t v_row_stride,

    uint32_t q_head_stride,
    uint32_t k_head_stride,
    uint32_t v_head_stride,

    uint32_t b,
    uint32_t h,
    uint32_t h_k,
    uint32_t d,
    uint32_t d_rounded,
    float softmax_scale,

    uint32_t seqlen_q,
    uint32_t seqlen_k,
    uint32_t seqlen_q_rounded,
    uint32_t seqlen_k_rounded,

    int is_causal
) {
    Flash_fwd_params params;
    // Reset the parameters so every field not assigned below is zero/null.
    memset(&params, 0, sizeof(params));

    // Set the pointers.
    params.q_ptr = q_ptr;
    params.k_ptr = k_ptr;
    params.v_ptr = v_ptr;
    params.o_ptr = o_ptr;

    // All strides are in elements, not bytes.
    // The batch strides must be stored as well: otherwise the memset leaves
    // them at 0 and every batch index would address the first batch.
    params.q_batch_stride = q_batch_stride;
    params.k_batch_stride = k_batch_stride;
    params.v_batch_stride = v_batch_stride;
    params.q_row_stride = q_row_stride;
    params.k_row_stride = k_row_stride;
    params.v_row_stride = v_row_stride;
    params.q_head_stride = q_head_stride;
    params.k_head_stride = k_head_stride;
    params.v_head_stride = v_head_stride;
    // NOTE(review): output strides and the softmax LSE pointer are not inputs of
    // this entry point, so they stay zero/null from the memset. Confirm the
    // kernel in flash_fwd_launch_template.h does not rely on them; otherwise this
    // C API needs additional arguments.

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.h_k = h_k;
    // Assumes h_k != 0 (and presumably that h_k divides h for grouped-query
    // attention) — not validated here.
    params.h_h_k_ratio = h / h_k;
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.seqlen_q_rounded = seqlen_q_rounded;
    params.seqlen_k_rounded = seqlen_k_rounded;
    params.d = d;
    params.d_rounded = d_rounded;
    params.is_causal = is_causal;

    // Set the different scale values.
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = softmax_scale * M_LOG2E;

    // Disable dropout. The reference dispatch above switches to its dropout path
    // when params.p_dropout < 1.f (p_dropout is the keep probability); presumably
    // run_mha_fwd_hdim32 uses the same switch, so the memset value of 0 would
    // select dropout with nothing kept.
    params.p_dropout = 1.f;

    cudaStream_t stream = 0; // Use the default stream.
    run_mha_fwd_<cutlass::half_t, 32>(params, stream);
}