/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/algorithm/copy.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>

using namespace cute;

template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using Element = elem_type;
    static constexpr bool Has_cp_async = true;
#else
    using Element = cutlass::half_t;
    static constexpr bool Has_cp_async = false;
#endif

    using ElementAccum = float;
    using index_t = uint32_t;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
    >;
    using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
    using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};

// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false,
         typename elem_type=cutlass::half_t,
         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<kNWarps>, _1, _1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));

    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockN>, Int<kHeadDim>>{}));

    using SmemLayoutAtomVtransposed = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
                    Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
                           Stride<_1, Int<kBlockKSmem>>>{}));
    using SmemLayoutVtransposed = decltype(tile_to_shape(
        SmemLayoutAtomVtransposed{},
        Shape<Int<kHeadDim>, Int<kBlockN>>{}));
    // Maybe the VtransposeNoSwizzle just needs to have the right shape
    // And the strides don't matter?
    using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
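
    // [Illustrative note, not part of the original header] A concrete reading of the swizzle
    // parameters above, assuming 2-byte (fp16/bf16) elements and kHeadDim a multiple of 64:
    // kBlockKSmem = 64, so each atom row is 64 elements = 128 bytes, and kSwizzle = 3 means
    // Swizzle<3, 3, 3>: the 3-bit row index (offset bits 6-8) is XORed into the 16-byte-chunk
    // index within the row (offset bits 3-5), so reading one column across the 8 rows of an
    // atom touches 8 different 16-byte chunks and avoids shared-memory bank conflicts.
    // The checks below hold for every instantiation admitted by the static_asserts above.
    static_assert(kSwizzle == 2 || kSwizzle == 3, "illustrative check: swizzle uses 2 or 3 bits");
    static_assert(kHeadDim % kBlockKSmem == 0, "illustrative check: smem atoms tile the head dim");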

    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<8>, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
    using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;

    static constexpr int kSmemQCount = size(SmemLayoutQ{});
    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
    static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
    // to the same banks.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

    static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
    using GmemLayoutAtomP = Layout<Shape<Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
                                   Stride<Int<kGmemThreadsPerRowP>, _1>>;

    using GmemTiledCopyP = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtomP{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

};

// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
// No_double_buffer is another option to reduce smem usage, but will slow things down.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,
         bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,
         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_bwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Is_V_in_regs = Is_V_in_regs_;
    static constexpr bool No_double_buffer = No_double_buffer_;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
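
    // [Illustrative note, not part of the original header] For example, with kHeadDim = 64:
    // kBlockKSmem = 64, kBlockKGmem = 64, kSwizzle = 3; with kHeadDim = 32: kBlockKSmem = 32,
    // kBlockKGmem = 32, kSwizzle = 2. The check below holds for every head dimension admitted
    // by the static_assert above (kHeadDim % 32 == 0).
    static_assert(kHeadDim % kBlockKSmem == 0 && kBlockKGmem % kBlockKSmem == 0,
                  "illustrative check: gmem/smem K-blocks tile the head dimension");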

    static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
    static_assert(kNWarps % AtomLayoutMSdP == 0);
    static_assert(kNWarps % AtomLayoutNdKV == 0);
    static_assert(kNWarps % AtomLayoutMdQ == 0);

    using TiledMmaSdP = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM

    using TiledMmadKV = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM

    using TiledMmadQ = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>,  // 2x4x1 or 4x2x1 thread group
        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM

    using SmemLayoutAtomQdO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQdO = decltype(tile_to_shape(
        SmemLayoutAtomQdO{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));

    using SmemLayoutAtomKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutKV = decltype(tile_to_shape(
        // SmemLayoutAtomQdO{},
        SmemLayoutAtomKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));

    using SmemLayoutAtomKtransposed = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
                           Stride<_1, Int<kBlockKSmem>>>{}));
    using SmemLayoutKtransposed = decltype(tile_to_shape(
        SmemLayoutAtomKtransposed{},
        make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
    // Maybe the KtransposeNoSwizzle just needs to have the right shape
    // And the strides don't matter?
    using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());

    // TODO: generalize to other values of kBlockN
    // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
    // static constexpr int kPBlockN = kBlockN;
    static_assert(kBlockN >= 64);
    // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
    static constexpr int kPBlockN = 64;
    static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
    // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
    static constexpr int kSwizzlePdS = 3;
    using SmemLayoutAtomPdS = decltype(
        composition(Swizzle<kSwizzlePdS, 3, 3>{},
                    Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
                           Stride<Int<kPBlockN>, _1>>{}));
    using SmemLayoutPdS = decltype(tile_to_shape(
        SmemLayoutAtomPdS{},
        make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
    using SmemLayoutAtomPdStransposed = decltype(
        composition(Swizzle<kSwizzlePdS, 3, 3>{},
                    Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
                           Stride<_1, Int<kPBlockN>>>{}));
    using SmemLayoutPdStransposed = decltype(tile_to_shape(
        SmemLayoutAtomPdStransposed{},
        make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
    using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
    using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;

    using SmemLayoutAtomQdOtransposed = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
                           Stride<_1, Int<kBlockKSmem>>>{}));
    using SmemLayoutQdOtransposed = decltype(tile_to_shape(
        SmemLayoutAtomQdOtransposed{},
        make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
    using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());

    using SmemLayoutAtomdKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdKV = decltype(tile_to_shape(
        SmemLayoutAtomdKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
    using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;

    using SmemLayoutAtomdQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;

    static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3);  // Double buffer for sQ
    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
    static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
    static constexpr int kSmemPCount = size(SmemLayoutPdS{});
    static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
    static constexpr int kSmemdPsumCount = kBlockM;
    static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
    static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
    static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
    static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
    static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
    static constexpr int kSmemSize = kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
    static constexpr int kSmemSize1colblock = kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + kSmemPSize
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
    static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
        + kSmemdSSize + kSmemPSize;
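
    // [Illustrative note, not part of the original header] Worked example of the smem accounting
    // above, assuming 2-byte elements, kBlockM = kBlockN = 64, kHeadDim = 64,
    // No_double_buffer = false and Is_V_in_regs = false:
    //   kSmemQdOSize = 3 * 64*64 * 2 B = 24 KB   (double-buffered sQ plus sdO)
    //   kSmemKVSize  = 2 * 64*64 * 2 B = 16 KB   (sK + sV)
    //   kSmemdSSize  =     64*64 * 2 B =  8 KB
    //   max(kSmemPSize, kSmemdQSize)   =  8 KB
    //   kSmemSize    = 24 + 16 + 8 + 8 = 56 KB
    // The check below is always true and just documents that the per-column-block total can
    // never exceed the full kSmemSize.
    static_assert(kSmemSize1colblock <= kSmemSize, "illustrative check: per-column-block smem fits");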

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydKV = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape<_32, _8>,   // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape<_16, _16>,  // Thread layout, 16 threads per row
               Stride<_16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store

    using GmemTiledCopydQaccumAtomicAdd = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        Layout<Shape<_8, _32>,  // Thread layout, 8 threads per row
                               Stride<_32, _1>>{},
                        Layout<Shape<_1, _1>>{}));  // Val layout, 1 val per store

};

////////////////////////////////////////////////////////////////////////////////////////////////////
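
// [Illustrative sketch, not part of the original header] Example instantiation showing how the
// traits are meant to be consumed. The tile sizes below (kHeadDim = 64, kBlockM = 128,
// kBlockN = 64, 4 warps) and the namespace name are placeholders chosen for the example, not the
// tuned defaults used by the FlashAttention launch code.
namespace flash_traits_example {  // hypothetical namespace, added for the example only

using FwdTraits = Flash_fwd_kernel_traits</*kHeadDim_=*/64, /*kBlockM_=*/128, /*kBlockN_=*/64,
                                          /*kNWarps_=*/4, /*Is_Q_in_regs_=*/false,
                                          /*Share_Q_K_smem_=*/false, cutlass::half_t>;

// A kernel built on these traits would typically launch with FwdTraits::kNThreads threads per
// block and request FwdTraits::kSmemSize bytes of dynamic shared memory.
static_assert(FwdTraits::kNThreads == 128, "4 warps -> 128 threads");
static_assert(FwdTraits::kSmemSize ==
              (128 * 64 + 2 * 64 * 64) * sizeof(FwdTraits::Element),
              "sQ + sK + sV, in bytes, when Q and K do not share smem");

}  // namespace flash_traits_example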