Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)

* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace
Author: Michael Feil
Date: 2024-12-31 09:32:22 +01:00
Committed by: GitHub
Parent: d60eba1408
Commit: 71cd6d5533
41 changed files with 140 additions and 83 deletions

@@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
////////////////////////////////////////////////////////////////////////////////////////////////////
// Soft-cap the attention logits in-place: tanh bounds every element to
// (-1, 1); the caller folds the cap (and softmax scale) into `softcap`
// and rescales the result afterwards.
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
    #pragma unroll
    for (int i = 0; i < size(tensor); ++i) {
        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
    }
}

// Backward-pass helper: given src_tensor holding the stored tanh outputs,
// write the tanh derivative, d/dx tanh(x) = 1 - tanh^2(x), scaled by
// `softcap`, into dst_tensor.
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
    #pragma unroll
    for (int i = 0; i < size(src_tensor); ++i) {
        dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
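
For reference, the math these device helpers implement: soft-capping replaces an attention logit x with cap * tanh(x / cap), smoothly bounding it to (-cap, cap); its derivative for the backward pass is 1 - tanh^2(x / cap). Below is a minimal host-side C++ sketch of that forward/backward pair, for illustration only: it keeps the cap explicit, whereas the kernels above receive a pre-folded scale in `softcap`. The cap value 50.0f and the helper names are assumptions, not taken from this diff.

#include <cmath>
#include <cstdio>
#include <initializer_list>

// Forward: smoothly bound a logit to (-cap, cap).
float softcap_fwd(float x, float cap) {
    return cap * std::tanh(x / cap);
}

// Backward: d/dx [cap * tanh(x / cap)] = 1 - tanh^2(x / cap).
// calculate_dtanh above computes the same quantity from the stored tanh
// outputs, times the folded scale.
float softcap_bwd(float x, float cap) {
    float t = std::tanh(x / cap);
    return 1.0f - t * t;
}

int main() {
    const float cap = 50.0f;  // illustrative cap value (assumption)
    for (float x : {-100.0f, -10.0f, 0.0f, 10.0f, 100.0f}) {
        std::printf("x = %7.1f  softcap = %8.4f  d/dx = %6.4f\n",
                    x, softcap_fwd(x, cap), softcap_bwd(x, cap));
    }
    return 0;
}

Large logits saturate toward +/-cap while small logits pass through nearly unchanged, which bounds the softmax inputs without a hard clip; this is the behavior the fused kernels reproduce with cutlass::fast_tanh.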