Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
@@ -54,6 +54,7 @@ fn main() -> Result<()> {
     println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
     println!("cargo:rerun-if-changed=kernels/block_info.h");
     println!("cargo:rerun-if-changed=kernels/static_switch.h");
+    println!("cargo:rerun-if-changed=kernels/hardware_info.h");
     let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
     let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
         Err(_) =>
Submodule candle-flash-attn/cutlass updated: 7d49e6c7e2...4c42f73fda
@@ -18,8 +18,9 @@ struct BlockInfo {
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
-        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+        , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
+        , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
         {
         }

@@ -30,13 +31,14 @@ struct BlockInfo {

     template <typename index_t>
     __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
-        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+        return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
     }

     const int sum_s_q;
     const int sum_s_k;
     const int actual_seqlen_q;
     // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+    const int leftpad_k;
     const int seqlen_k_cache;
     const int actual_seqlen_k;
 };
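Note: the new leftpad_k member shifts both the cached K length and the K base offset right by the per-batch left padding. A host-side sketch of the same bookkeeping (a plain-struct mirror of BlockInfo, names chosen here for illustration):

    #include <cstdint>
    #include <cassert>

    // Host-side mirror of the new BlockInfo arithmetic (illustrative only).
    struct BlockInfoSketch {
        int sum_s_k;    // -1 when sequences are not packed (no cu_seqlens_k)
        int leftpad_k;  // per-batch left padding; 0 when params.leftpad_k == nullptr
        // Where row 0 of the *usable* K cache starts in global memory:
        // matches the new k_offset(), skipping the padding rows up front.
        int64_t k_offset(int64_t batch_stride, int64_t row_stride, int bidb) const {
            return sum_s_k == -1
                ? bidb * batch_stride + leftpad_k * row_stride            // padded batch layout
                : int64_t(uint32_t(sum_s_k + leftpad_k)) * row_stride;    // packed (varlen) layout
        }
    };

    int main() {
        // Batch entry 1 of a padded layout: 128 rows per entry, 4096 elements per row.
        BlockInfoSketch b{/*sum_s_k=*/-1, /*leftpad_k=*/16};
        // The 16 left-padding rows are skipped before the real K rows begin.
        assert(b.k_offset(/*batch_stride=*/128 * 4096, /*row_stride=*/4096, /*bidb=*/1)
               == 1LL * 128 * 4096 + 16 * 4096);
        return 0;
    }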
@@ -7,13 +7,7 @@
 #include <cuda.h>
 #include <vector>

-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-//
-// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
+// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState

 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;

@@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params {
     // array of length b+1 holding starting offset of each sequence.
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
+    int * __restrict__ leftpad_k;

     // If provided, the actual length of each k sequence.
     int * __restrict__ seqused_k;

@@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
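For reference, the cu_seqlens_q / cu_seqlens_k convention noted above ("array of length b+1 holding starting offset of each sequence") is the usual varlen packing; a minimal host-side sketch:

    #include <cstdio>
    #include <vector>

    int main() {
        // Three sequences of lengths 5, 3, 7 packed back to back (varlen batch, b = 3).
        std::vector<int> seqlens = {5, 3, 7};
        std::vector<int> cu_seqlens(seqlens.size() + 1, 0);  // length b+1
        for (size_t i = 0; i < seqlens.size(); ++i)
            cu_seqlens[i + 1] = cu_seqlens[i] + seqlens[i];
        // cu_seqlens = {0, 5, 8, 15}: sequence i occupies rows [cu_seqlens[i], cu_seqlens[i+1]).
        for (size_t i = 0; i < seqlens.size(); ++i)
            printf("seq %zu: rows [%d, %d)\n", i, cu_seqlens[i], cu_seqlens[i + 1]);
        return 0;
    }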
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"

(This identical one-line copyright bump repeats in each of the auto-generated per-head-dimension kernel files; the remaining hunks are collapsed here.)
@@ -4,6 +4,8 @@

 #pragma once

+// #include "philox_unpack.cuh" // For at::cuda::philox::unpack
+
 #include <cute/tensor.hpp>

 #include <cutlass/cutlass.h>

@@ -22,14 +24,6 @@ namespace flash {

 using namespace cute;

-template <typename Engine, typename Layout>
-__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
-    #pragma unroll
-    for (int i = 0; i < size(tensor); ++i) {
-        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
-    }
-}
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>

@@ -328,7 +322,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }

         mask.template apply_mask<Is_causal, Is_even_MN>(

@@ -394,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }

         flash::cp_async_wait<0>();

@@ -691,7 +685,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
-        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                   Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
                                   make_stride(params.rotary_dim / 2, _1{}));

@@ -712,9 +706,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if (cute::thread(8, 0)) { print_tensor(gCos); }
         // if (cute::thread(0, 0)) { print_tensor(tRgCos); }

-        const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        const index_t row_offset_knew = bidb * params.knew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
-        const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        const index_t row_offset_vnew = bidb * params.vnew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
         // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
         // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].

@@ -792,7 +788,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                                binfo.actual_seqlen_q - m_block * kBlockM);
         } else {
-            const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+            const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
             // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
             // We do this by setting the row stride of gCos / gSin to 0.
             Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),

@@ -886,7 +882,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }


@@ -961,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }

         flash::cp_async_wait<0>();

@@ -1226,7 +1222,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     constexpr int kBlockN = kNThreads / kBlockM;
     using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
     using GmemTiledCopyOaccum = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
     GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
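Note: the two row_offset_cossin changes above index the rotary cos/sin tables from the first real token rather than the first padded row. A host-side sketch of that index arithmetic (illustrative names; leftpad = 0 corresponds to params.leftpad_k == nullptr):

    #include <cstdio>

    // Recomputes the adjusted rotary cos/sin row offset from the diff above.
    long row_offset_cossin(int n_block_max, int kBlockN, int leftpad, int rotary_dim) {
        return (long(n_block_max - 1) * kBlockN + leftpad) * (rotary_dim / 2);
    }

    int main() {
        // With 16 rows of left padding, the rotary table is read 16 positions later.
        printf("%ld\n", row_offset_cossin(/*n_block_max=*/4, /*kBlockN=*/64, /*leftpad=*/0,  /*rotary_dim=*/128)); // 12288
        printf("%ld\n", row_offset_cossin(4, 64, /*leftpad=*/16, 128));                                            // 13312
        return 0;
    }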
@@ -3,11 +3,11 @@
 ******************************************************************************/

 #pragma once

-// #include <ATen/cuda/CUDAContext.h>
-// #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

+#include "error.h"
 #include "static_switch.h"
+#include "hardware_info.h"
 #include "flash.h"
 #include "flash_fwd_kernel.h"

@@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             // If return_softmax, set IsEvenMNConst to false to reduce number of templates
             // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
             // If Is_local, set Is_causal to false
-            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
+            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
             // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
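Note: the updated kernel selection compiles dropout support out whenever soft-capping is enabled (Is_dropout && !Is_softcap), which cuts down the number of template combinations instantiated for softcap kernels. A simplified sketch of the switch pattern used by kernels/static_switch.h to turn runtime flags into template parameters (an illustration, not the exact macro):

    #include <cstdio>

    // Simplified BOOL_SWITCH: turns a runtime bool into a compile-time constant
    // by instantiating both branches of the enclosed lambda.
    #define BOOL_SWITCH(COND, CONST_NAME, ...)        \
        [&] {                                         \
            if (COND) {                               \
                constexpr bool CONST_NAME = true;     \
                return __VA_ARGS__();                 \
            } else {                                  \
                constexpr bool CONST_NAME = false;    \
                return __VA_ARGS__();                 \
            }                                         \
        }()

    template <bool Is_dropout, bool Is_softcap>
    void run_kernel() { printf("dropout=%d softcap=%d\n", Is_dropout, Is_softcap); }

    int main() {
        bool use_dropout = true, use_softcap = true;
        BOOL_SWITCH(use_softcap, Is_softcap, [&] {
            BOOL_SWITCH(use_dropout && !Is_softcap, Is_dropout, [&] {
                // As in the diff: dropout is compiled out whenever softcap is on.
                run_kernel<Is_dropout, Is_softcap>();
            });
        });
        return 0;
    }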
@@ -205,7 +205,8 @@ inline bool cuda_is_sm8x() {
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
-    bool is_sm8x = cuda_is_sm8x();
+    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+    bool is_sm8x = cc_major == 8 && cc_minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
         if (is_sm8x) {

@@ -228,7 +229,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
-    bool is_sm8x = cuda_is_sm8x();
+    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+    bool is_sm8x = cc_major == 8 && cc_minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if constexpr(!Is_dropout) {
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),

@@ -262,7 +264,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
-    bool is_sm8x = cuda_is_sm8x();
+    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+    bool is_sm8x = cc_major == 8 && cc_minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // For A100, H100, 128 x 32 is the fastest.
         // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
candle-flash-attn/kernels/hardware_info.h (new file, 42 lines)
@@ -0,0 +1,42 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <tuple>
+#include <cstdio>
+
+#if !defined(__CUDACC_RTC__)
+#include "cuda_runtime.h"
+#endif
+
+#define CHECK_CUDA(call)                                                      \
+  do {                                                                        \
+    cudaError_t status_ = call;                                               \
+    if (status_ != cudaSuccess) {                                             \
+      fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,         \
+              cudaGetErrorString(status_));                                   \
+      exit(1);                                                                \
+    }                                                                         \
+  } while (0)
+
+inline int get_current_device() {
+    int device;
+    CHECK_CUDA(cudaGetDevice(&device));
+    return device;
+}
+
+inline std::tuple<int, int> get_compute_capability(int device) {
+    int capability_major, capability_minor;
+    CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device));
+    return {capability_major, capability_minor};
+}
+
+inline int get_num_sm(int device) {
+    int multiprocessor_count;
+    CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
+    return multiprocessor_count;
+}
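Usage matches the run_mha_fwd_hdim* changes above; a host-only sketch, assuming the header is on the include path and the CUDA runtime is linked:

    #include <cstdio>
    #include "hardware_info.h"  // the new header added above

    int main() {
        int device = get_current_device();
        auto [cc_major, cc_minor] = get_compute_capability(device);
        // Same predicate as in the launch templates: sm86/sm89 but not sm80.
        bool is_sm8x = cc_major == 8 && cc_minor > 0;
        printf("device %d: sm_%d%d (is_sm8x=%d), %d SMs\n",
               device, cc_major, cc_minor, int(is_sm8x), get_num_sm(device));
        return 0;
    }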
@@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
     using SmemLayoutO = decltype(tile_to_shape(
         SmemLayoutAtomO{},
         Shape<Int<kBlockM>, Int<kHeadDim>>{}));
-    using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
-    using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
+    using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
+    using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;

     static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
     static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);

@@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base {
     using Gmem_copy_struct = std::conditional_t<
         Has_cp_async,
         SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
-        DefaultCopy
+        AutoVectorizingCopyWithAssumedAlignment<128>
     >;
     using GmemTiledCopyQKV = decltype(
         make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
     using GmemTiledCopyO = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store

@@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
                Stride< _16, _1>>
     >;
     using GmemTiledCopyOaccum = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
     using GmemLayoutAtomRotcossin = GmemLayoutAtom;

@@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base {
                         GmemLayoutAtomRotcossin{},
                         Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
     using GmemTiledCopyRotcossinCont = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                         GmemLayoutAtomRotcossin{},
                         Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
 };

@@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base {
         composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
     using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));

-    using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

     using SmemLayoutQdOtransposed = decltype(
         composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));

@@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base {
     using SmemLayoutdKV = decltype(tile_to_shape(
         SmemLayoutAtomdKV{},
         make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
-    using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

     using SmemLayoutAtomdQ = decltype(
         composition(Swizzle<kSwizzle, 3, 3>{},

@@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base {
     using SmemLayoutdQ = decltype(tile_to_shape(
         SmemLayoutAtomdQ{},
         make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
-    using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

     // Double buffer for sQ
     static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);

@@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base {
     using Gmem_copy_struct = std::conditional_t<
         Has_cp_async,
         SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
-        DefaultCopy
+        AutoVectorizingCopyWithAssumedAlignment<128>
     >;
     using GmemTiledCopyQKV = decltype(
         make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
     using GmemTiledCopydO = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
     using GmemTiledCopydKV = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
     using GmemTiledCopydQ = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
     using GmemLayoutAtomdQaccum = std::conditional_t<

@@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base {
                Stride< _16, _1>>
     >;
     using GmemTiledCopydQaccum = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         GmemLayoutAtomdQaccum{},
                         Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store

     using GmemTiledCopydQaccumAtomicAdd = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
                                Stride<_32, _1>>{},
                         Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tensor); ++i) {
|
||||
tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(src_tensor); ++i) {
|
||||
dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace flash
|
||||
|
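For reference, the soft-capping these helpers implement is capped = cap * tanh(scores / cap), a smooth clamp of attention logits into (-cap, cap) as used by models such as Gemma 2; upstream flash-attention folds the softmax scale into params.softcap so the device loop only has to compute tanh(x * softcap). A host-side check of the forward value and the backward factor that calculate_dtanh corresponds to:

    #include <cmath>
    #include <cstdio>

    int main() {
        const float cap = 30.f;  // softcap value, e.g. 30 for attention logits
        const float s = 55.f;    // raw (scaled) attention score
        // Forward: smooth clamp of s into (-cap, cap).
        float capped = cap * std::tanh(s / cap);
        // Backward: d(capped)/ds = 1 - tanh^2(s/cap); calculate_dtanh computes
        // the same factor from the already-tanh'd value (times the folded scale).
        float t = std::tanh(s / cap);
        float dcapped_ds = 1.f - t * t;
        printf("capped=%.4f  d/ds=%.6f\n", capped, dcapped_ds);
        return 0;
    }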