Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)

* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace
Author: Michael Feil
Committed by: GitHub
Date: 2024-12-31 09:32:22 +01:00
Parent: d60eba1408
Commit: 71cd6d5533

41 changed files with 140 additions and 83 deletions


@@ -0,0 +1,42 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once

#include <tuple>
#include <cstdio>
#include <cstdlib>  // for exit()

#if !defined(__CUDACC_RTC__)
#include "cuda_runtime.h"
#endif

// Abort with a readable message if a CUDA runtime call fails.
#define CHECK_CUDA(call)                                                    \
    do {                                                                    \
        cudaError_t status_ = call;                                         \
        if (status_ != cudaSuccess) {                                       \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, \
                    cudaGetErrorString(status_));                           \
            exit(1);                                                        \
        }                                                                   \
    } while (0)
// Return the index of the CUDA device current on the calling thread.
inline int get_current_device() {
    int device;
    CHECK_CUDA(cudaGetDevice(&device));
    return device;
}

// Return the (major, minor) compute capability of the given device.
inline std::tuple<int, int> get_compute_capability(int device) {
    int capability_major, capability_minor;
    CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device));
    CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device));
    return {capability_major, capability_minor};
}

// Return the number of streaming multiprocessors on the given device.
inline int get_num_sm(int device) {
    int multiprocessor_count;
    CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
    return multiprocessor_count;
}
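
Not part of the diff: below is a minimal host-side sketch of how these helpers could be used to gate kernel dispatch, for example only taking a newer-architecture path such as the hdim224 flash_fwd templates on recent GPUs. The included header name and the sm_8x threshold are illustrative assumptions, not something this commit defines.

// Hypothetical usage sketch (not from this commit). The header name below is a
// placeholder for the file added in this diff.
#include <cstdio>
#include "cuda_device_utils.h"  // placeholder name for the header shown above

int main() {
    int device = get_current_device();
    auto [cc_major, cc_minor] = get_compute_capability(device);
    int num_sm = get_num_sm(device);

    // Assumption for illustration: route to sm_8x+ kernel variants
    // (e.g. the hdim224 flash_fwd templates) only on Ampere or newer.
    bool use_sm8x_path = cc_major >= 8;
    printf("device %d: sm_%d%d, %d SMs, sm8x path enabled: %s\n",
           device, cc_major, cc_minor, num_sm, use_sm8x_path ? "yes" : "no");
    return 0;
}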