Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)

* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace
This commit is contained in:
Michael Feil
2024-12-31 09:32:22 +01:00
committed by GitHub
parent d60eba1408
commit 71cd6d5533
41 changed files with 140 additions and 83 deletions

View File

@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base {
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
AutoVectorizingCopyWithAssumedAlignment<128>
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
Stride< _16, _1>>
>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base {
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinCont = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
};
@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base {
composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
using SmemLayoutQdOtransposed = decltype(
composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutdKV = decltype(tile_to_shape(
SmemLayoutAtomdKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
using SmemLayoutAtomdQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutdQ = decltype(tile_to_shape(
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
// Double buffer for sQ
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base {
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
AutoVectorizingCopyWithAssumedAlignment<128>
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopydO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydKV = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomdQaccum = std::conditional_t<
@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base {
Stride< _16, _1>>
>;
using GmemTiledCopydQaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomdQaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemTiledCopydQaccumAtomicAdd = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
Stride<_32, _1>>{},
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store