Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)

* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
@@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+    }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(src_tensor); ++i) {
+        dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
+    }
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 }  // namespace flash
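For context: soft-capping squashes attention logits through tanh so they stay bounded, and the two helpers added here split that into a forward step, y = tanh(x * softcap), and a backward step, dy/dx = (1 - y^2) * softcap, which is the tanh derivative with the cap scale folded in. Below is a minimal host-side C++ sketch of the same math, for illustration only; the standalone setup and the example values are assumptions, not part of this commit, and the kernel uses cutlass::fast_tanh rather than std::tanh.

// Host-side sketch of the math behind apply_softcap / calculate_dtanh.
// Illustrative only; softcap and x are made-up example values.
#include <cmath>
#include <cstdio>

int main() {
    const float softcap = 0.5f;  // assumed cap/scale factor
    const float x = 2.0f;        // an example pre-softcap attention score

    // Forward: apply_softcap computes tanh(x * softcap).
    const float y = std::tanh(x * softcap);

    // Backward: calculate_dtanh computes (1 - y^2) * softcap, the
    // derivative of tanh(x * softcap) with respect to x, evaluated
    // from the stored tanh output y.
    const float dy_dx = (1.0f - y * y) * softcap;

    std::printf("y = %f, dy/dx = %f\n", y, dy_dx);
    return 0;
}

Reusing the stored tanh output in the backward pass is what lets calculate_dtanh avoid recomputing tanh: the derivative needs only the squared forward result.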