mirror of https://github.com/huggingface/candle.git
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
@@ -4,6 +4,8 @@
 
 #pragma once
 
+// #include "philox_unpack.cuh" // For at::cuda::philox::unpack
+
 #include <cute/tensor.hpp>
 
 #include <cutlass/cutlass.h>
@@ -22,14 +24,6 @@ namespace flash {
 
 using namespace cute;
 
-template <typename Engine, typename Layout>
-__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
-    #pragma unroll
-    for (int i = 0; i < size(tensor); ++i) {
-        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
-    }
-}
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
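Note: the removed helper is not gone. The call sites below switch from `apply_softcap` to `flash::apply_softcap`, matching upstream flash-attention, where the helper now lives in a shared header inside the `flash` namespace. For reference, a minimal host-side sketch of the same element-wise soft-capping (hypothetical reference code; `std::tanh` stands in for `cutlass::fast_tanh`):

#include <cmath>
#include <cstdio>
#include <vector>

// Host mirror of the removed device loop: scores[i] = tanh(scores[i] * softcap).
// (In upstream flash-attention the cap and the softmax scale are folded into
// params.softcap / params.scale_softmax; this sketch only mirrors the
// element-wise saturation step.)
void apply_softcap_ref(std::vector<float> &scores, float softcap) {
    for (float &s : scores) {
        s = std::tanh(s * softcap);
    }
}

int main() {
    std::vector<float> scores = {-80.0f, -2.0f, 0.5f, 120.0f};
    apply_softcap_ref(scores, 1.0f / 50.0f);       // illustrative cap of 50
    for (float s : scores) std::printf("%f ", s);  // extreme logits saturate toward +/-1
    std::printf("\n");
}

Soft-capping bounds the raw attention logits to a fixed range before softmax, which keeps them finite in half precision.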
@@ -328,7 +322,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
         mask.template apply_mask<Is_causal, Is_even_MN>(
@@ -394,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
         flash::cp_async_wait<0>();
@@ -691,7 +685,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
-        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                   Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
                                   make_stride(params.rotary_dim / 2, _1{}));
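Note: the only change in this hunk is the `params.leftpad_k` term. When a batch entry's KV cache is left-padded, the real tokens start `leftpad_k[bidb]` slots in, so the row index into the rotary cos/sin tables must be shifted by that amount, otherwise the rotary phase would be computed for the padding. A standalone sketch of the index arithmetic (the function itself is hypothetical; the names mirror the kernel):

#include <cstdint>
#include <cstdio>

using index_t = int64_t;

// Row offset into the flat [seqlen, rotary_dim / 2] cos/sin tables for the
// first K row handled by this block. With left padding, real tokens start at
// leftpad_k[bidb], so the rotary position shifts by that amount.
index_t cossin_row_offset(int n_block_max, int kBlockN, int rotary_dim,
                          const int *leftpad_k, int bidb) {
    const int leftpad = (leftpad_k == nullptr) ? 0 : leftpad_k[bidb];
    return (static_cast<index_t>(n_block_max - 1) * kBlockN + leftpad)
           * (rotary_dim / 2);
}

int main() {
    const int leftpad[2] = {0, 16};
    // Batch 1 starts 16 padded slots in, so its rotary rows shift by 16 * (64 / 2).
    std::printf("%lld\n", (long long)cossin_row_offset(4, 32, 64, leftpad, 0));  // 3072
    std::printf("%lld\n", (long long)cossin_row_offset(4, 32, 64, leftpad, 1));  // 3584
}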
@@ -712,9 +706,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if (cute::thread(8, 0)) { print_tensor(gCos); }
         // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
 
-        const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        const index_t row_offset_knew = bidb * params.knew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
-        const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        const index_t row_offset_vnew = bidb * params.vnew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
         // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
         // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
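Note: `binfo.k_offset` is dropped here in favor of explicit strided indexing into the appended K/V tensors: batch stride, plus row stride to the first row of the last block, plus head stride for this KV head (`bidh / params.h_h_k_ratio` maps a query head to its shared KV head under MQA/GQA). A hypothetical standalone version of the same arithmetic:

#include <cstdint>
#include <cstdio>

using index_t = int64_t;

// Flat element offset of row ((n_block_max - 1) * kBlockN) for batch `bidb`
// and KV head (bidh / h_h_k_ratio) in a strided [batch, seq, kv_heads, dim]
// tensor, mirroring row_offset_knew / row_offset_vnew above.
index_t new_kv_row_offset(index_t batch_stride, index_t row_stride, index_t head_stride,
                          int bidb, int bidh, int h_h_k_ratio,
                          int n_block_max, int kBlockN) {
    return bidb * batch_stride
         + static_cast<index_t>(n_block_max - 1) * kBlockN * row_stride
         + (bidh / h_h_k_ratio) * head_stride;
}

int main() {
    // 2 batches, seqlen 128, 4 KV heads, head_dim 64, contiguous layout:
    // strides are [128*4*64, 4*64, 64, 1]; 8 query heads -> h_h_k_ratio = 2.
    index_t off = new_kv_row_offset(128 * 4 * 64, 4 * 64, 64,
                                    /*bidb=*/1, /*bidh=*/5, /*h_h_k_ratio=*/2,
                                    /*n_block_max=*/2, /*kBlockN=*/64);
    std::printf("%lld\n", (long long)off);  // 32768 + 16384 + 128 = 49280
}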
@@ -792,7 +788,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                                binfo.actual_seqlen_q - m_block * kBlockM);
         } else {
-            const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+            const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
             // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
             // We do this by setting the row stride of gCos / gSin to 0.
             Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
@@ -886,7 +882,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
 
@@ -961,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
        }
 
         flash::cp_async_wait<0>();
@@ -1226,7 +1222,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     constexpr int kBlockN = kNThreads / kBlockM;
     using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
     using GmemTiledCopyOaccum = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
     GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
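Note: swapping `DefaultCopy` for cute's `AutoVectorizingCopyWithAssumedAlignment<128>` lets the tiled copy assume 128-bit alignment, so the four fp32 values each thread owns (the `Shape<_1, _4>` val layout) can be written as a single 16-byte store rather than four scalar ones. A trivial host-side check of why the val layout and the assumed alignment line up (illustrative only):

#include <cstdint>
#include <cstdio>

// The Shape<_1, _4> val layout gives each thread 4 fp32 values per copy;
// 4 * sizeof(float) = 16 bytes = 128 bits, exactly the alignment that
// AutoVectorizingCopyWithAssumedAlignment<128> asks the copy to assume,
// so the store can be emitted as one 128-bit transaction.
int main() {
    constexpr int vals_per_store = 4;  // Layout<Shape<_1, _4>>
    constexpr int bits = vals_per_store * sizeof(float) * 8;
    static_assert(bits == 128, "val layout matches the assumed alignment");

    alignas(16) float oaccum[4] = {0.f, 1.f, 2.f, 3.f};
    std::printf("%d-bit stores; buffer 16-byte aligned? %s\n", bits,
                reinterpret_cast<std::uintptr_t>(oaccum) % 16 == 0 ? "yes" : "no");
}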