mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00

* update flash-attn v1 * restore: hdim224 * add 224 flash_fwd_template * remove whitespace
43 lines
1.6 KiB
C++
43 lines
1.6 KiB
C++
/******************************************************************************
|
|
* Copyright (c) 2024, Tri Dao.
|
|
******************************************************************************/
|
|
|
|
#pragma once
|
|
|
|
#include <cstdio>
#include <cstdlib>
#include <tuple>
|
|
|
|
#if !defined(__CUDACC_RTC__)
|
|
#include "cuda_runtime.h"
|
|
#endif
|
|
|
|
// Evaluates a CUDA runtime call exactly once and terminates the process if the
// returned status is not cudaSuccess.
//
// On failure it prints "CUDA error (<file>:<line>): <description>" to stderr,
// where <file>/<line> come from the place the macro is expanded and
// <description> is the text returned by cudaGetErrorString, then exits with
// status 1.
//
// The do { ... } while (0) wrapper makes the expansion a single statement, so
// the macro can be used in an unbraced if/else body. The argument is
// parenthesized so that it is always consumed as one complete expression.
#define CHECK_CUDA(call)                                                     \
    do {                                                                     \
        cudaError_t status_ = (call);                                        \
        if (status_ != cudaSuccess) {                                        \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,  \
                    cudaGetErrorString(status_));                            \
            std::exit(1);                                                    \
        }                                                                    \
    } while (0)
|
|
|
|
|
|
// Returns the device ordinal reported by cudaGetDevice.
// A failing query is handled by CHECK_CUDA, which reports the error on stderr
// and exits the process, so this function only returns on success.
inline int get_current_device() {
    int current_device_id = 0;
    CHECK_CUDA(cudaGetDevice(&current_device_id));
    return current_device_id;
}
|
|
|
|
// Queries the compute capability of `device` and returns it as a
// (major, minor) tuple.
//
// The major attribute is queried first, then the minor one; either query
// failing is handled by CHECK_CUDA (error printed to stderr, process exits).
inline std::tuple<int, int> get_compute_capability(int device) {
    int cc_major = 0;
    int cc_minor = 0;
    CHECK_CUDA(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device));
    CHECK_CUDA(cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, device));
    return std::make_tuple(cc_major, cc_minor);
}
|
|
|
|
// Returns the value of the cudaDevAttrMultiProcessorCount attribute for
// `device`. A failing query is handled by CHECK_CUDA (error printed to
// stderr, process exits).
inline int get_num_sm(int device) {
    int sm_count = 0;
    CHECK_CUDA(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device));
    return sm_count;
}
|