Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
@@ -7,13 +7,7 @@
 #include <cuda.h>
 #include <vector>
 
-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-//
-// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
+// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
 
 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
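For context on the two constants kept by this hunk: in flash-attn's headers, TOTAL_DIM and H_DIM name the axes of the packed (total_tokens, num_heads, head_dim) activation layout, so code can index shapes by role rather than by bare integers. A minimal sketch of that indexing convention, with made-up shape values:

```cpp
#include <array>
#include <cstdio>

// Same constants as in flash.h above.
constexpr int TOTAL_DIM = 0; // axis 0: total tokens packed across the batch
constexpr int H_DIM = 1;     // axis 1: number of attention heads

int main() {
    // Hypothetical Q shape: 4096 packed tokens, 32 heads, head_dim 128.
    std::array<long, 3> q_shape = {4096, 32, 128};
    std::printf("total tokens: %ld, heads: %ld\n",
                q_shape[TOTAL_DIM], q_shape[H_DIM]);
    return 0;
}
```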
@@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params {
     // array of length b+1 holding starting offset of each sequence.
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
+    int * __restrict__ leftpad_k;
 
     // If provided, the actual length of each k sequence.
     int * __restrict__ seqused_k;
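The cu_seqlens_q / cu_seqlens_k convention above (an array of length b+1 holding cumulative offsets) is what lets variable-length sequences be packed into one buffer with no padding. A minimal host-side sketch of that convention; the lengths below are made up, and only the indexing rule is taken from the header comment:

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Three sequences of lengths 5, 3 and 8 packed back to back:
    // cu_seqlens has b+1 = 4 entries; entry i is the starting offset of sequence i.
    std::vector<int> cu_seqlens = {0, 5, 8, 16};
    int b = static_cast<int>(cu_seqlens.size()) - 1;
    for (int i = 0; i < b; ++i) {
        // Sequence i occupies rows [cu_seqlens[i], cu_seqlens[i+1]).
        std::printf("seq %d: offset %d, length %d\n",
                    i, cu_seqlens[i], cu_seqlens[i + 1] - cu_seqlens[i]);
    }
    return 0;
}
```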
@@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
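The run_mha_fwd_ declarations above are instantiated once per (element type, head dimension, causality) combination, including the hdim224 template this commit restores, and a runtime dispatcher selects the matching instantiation. A simplified stand-alone sketch of that pattern; the dispatcher body and the set of supported sizes here are illustrative, not candle's actual dispatch code:

```cpp
#include <stdexcept>

// Stand-in for a kernel specialized at compile time on the head dimension,
// in the spirit of run_mha_fwd_<T, Headdim, Is_causal> declared above.
template <int kHeadDim>
void run_fwd_for_headdim(bool is_causal) {
    // A real implementation would launch the CUDA kernel built for kHeadDim,
    // further specialized on is_causal.
    (void)is_causal;
}

// Map a runtime head dimension onto a compile-time template argument.
void dispatch_fwd(int headdim, bool is_causal) {
    switch (headdim) {
        case 64:  run_fwd_for_headdim<64>(is_causal);  break;
        case 128: run_fwd_for_headdim<128>(is_causal); break;
        case 224: run_fwd_for_headdim<224>(is_causal); break; // hdim224, restored by this commit
        case 256: run_fwd_for_headdim<256>(is_causal); break;
        default:  throw std::runtime_error("unsupported head dimension");
    }
}
```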
|