mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1 * restore: hdim224 * add 224 flash_fwd_template * remove whitespace
This commit is contained in:
@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
|
||||
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
|
||||
using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
|
||||
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
|
||||
|
||||
static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
|
||||
@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
using Gmem_copy_struct = std::conditional_t<
|
||||
Has_cp_async,
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||
DefaultCopy
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopyO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
Stride< _16, _1>>
|
||||
>;
|
||||
using GmemTiledCopyOaccum = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
GmemLayoutAtomOaccum{},
|
||||
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
|
||||
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
|
||||
@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
GmemLayoutAtomRotcossin{},
|
||||
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
|
||||
using GmemTiledCopyRotcossinCont = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
||||
GmemLayoutAtomRotcossin{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
|
||||
};
|
||||
@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
|
||||
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
|
||||
|
||||
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
|
||||
using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
|
||||
|
||||
using SmemLayoutQdOtransposed = decltype(
|
||||
composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
|
||||
@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
using SmemLayoutdKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomdKV{},
|
||||
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
|
||||
using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
|
||||
using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
|
||||
|
||||
using SmemLayoutAtomdQ = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
using SmemLayoutdQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomdQ{},
|
||||
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
|
||||
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
|
||||
using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
|
||||
|
||||
// Double buffer for sQ
|
||||
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
|
||||
@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
using Gmem_copy_struct = std::conditional_t<
|
||||
Has_cp_async,
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||
DefaultCopy
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopydO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||
using GmemTiledCopydKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||
using GmemTiledCopydQ = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||
using GmemLayoutAtomdQaccum = std::conditional_t<
|
||||
@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
Stride< _16, _1>>
|
||||
>;
|
||||
using GmemTiledCopydQaccum = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
GmemLayoutAtomdQaccum{},
|
||||
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
|
||||
|
||||
using GmemTiledCopydQaccumAtomicAdd = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
|
||||
Stride<_32, _1>>{},
|
||||
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store
|
||||
|
Reference in New Issue
Block a user