mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Update the flash attn kernels. (#2333)
This commit is contained in:
@ -14,8 +14,7 @@
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
#include <cute/algorithm/gemm.hpp>
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
@ -29,10 +28,10 @@ namespace flash {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
inline __device__ uint32_t relu2(const uint32_t x);
|
||||
__forceinline__ __device__ uint32_t relu2(const uint32_t x);
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
|
||||
__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
|
||||
uint32_t res;
|
||||
const uint32_t zero = 0u;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
@ -50,7 +49,7 @@ inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
template<>
|
||||
inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
|
||||
__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
|
||||
uint32_t res;
|
||||
const uint32_t zero = 0u;
|
||||
asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
|
||||
@ -63,10 +62,10 @@ inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
|
||||
template<typename T>
|
||||
inline __device__ uint32_t convert_relu2(const float2 x);
|
||||
__forceinline__ __device__ uint32_t convert_relu2(const float2 x);
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
|
||||
__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
|
||||
uint32_t res;
|
||||
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
|
||||
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
|
||||
@ -75,7 +74,7 @@ inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
|
||||
__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
|
||||
uint32_t res;
|
||||
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
|
||||
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
|
||||
@ -89,20 +88,20 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
|
||||
|
||||
template<typename T>
|
||||
struct MaxOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxOp<float> {
|
||||
// This is slightly faster
|
||||
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
|
||||
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -111,7 +110,7 @@ template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
static __device__ __forceinline__ T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
@ -123,7 +122,7 @@ struct Allreduce {
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
static __device__ __forceinline__ T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
@ -135,7 +134,7 @@ template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename
|
||||
typename Tensor2, typename Tensor3, typename Tensor4,
|
||||
typename TiledMma, typename TiledCopyA, typename TiledCopyB,
|
||||
typename ThrCopyA, typename ThrCopyB>
|
||||
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
|
||||
__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
|
||||
Tensor4 const& tCsB, TiledMma tiled_mma,
|
||||
TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
|
||||
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
|
||||
@ -162,9 +161,9 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
|
||||
|
||||
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
|
||||
typename TiledMma, typename TiledCopy, typename ThrCopy>
|
||||
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
|
||||
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
|
||||
ThrCopy smem_thr_copy_B) {
|
||||
__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
|
||||
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
|
||||
ThrCopy smem_thr_copy_B) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
|
||||
@ -184,42 +183,48 @@ inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB
|
||||
|
||||
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
template<typename Layout>
|
||||
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
|
||||
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
|
||||
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
|
||||
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
|
||||
template<typename MMA_traits, typename Layout>
|
||||
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
|
||||
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
|
||||
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
|
||||
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
|
||||
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
|
||||
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
|
||||
// TD [2023-08-13]: Same error as above on Cutlass 3.2
|
||||
// return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
|
||||
// get<0, 1>(l),
|
||||
// get<1, 1, 1>(l));
|
||||
return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
|
||||
get<1>(get<0>(l)),
|
||||
get<1>(get<1>(get<1>(l))));
|
||||
if constexpr (mma_shape_K == 8) {
|
||||
return acc_layout;
|
||||
} else {
|
||||
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
|
||||
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
template<typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
|
||||
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
@ -231,7 +236,7 @@ inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
|
||||
__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
static_assert(numel % 2 == 0);
|
||||
using value_t = typename Engine::value_type;
|
||||
@ -247,7 +252,7 @@ inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
|
||||
|
||||
// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
|
||||
__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
|
||||
static_assert(std::is_same_v<float, From_type>);
|
||||
@ -289,7 +294,7 @@ void cp_async_wait() {
|
||||
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
|
||||
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
|
||||
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
|
||||
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
@ -355,4 +360,34 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
|
||||
Tensor<Engine3, Layout3> const &predicate_K,
|
||||
const int max_MN=0, const int min_MN=0) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
|
||||
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || predicate_K(k)) {
|
||||
cute::copy(S(_, m, k), D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace flash
|
||||
|
Reference in New Issue
Block a user