Rework the cuda casting bits. (#1112)

This commit is contained in:
Laurent Mazare
2023-10-17 09:44:51 +01:00
committed by GitHub
parent 00948eb656
commit 2fe24ac5b1

View File

@ -1,6 +1,53 @@
#include "cuda_utils.cuh"
#include<stdint.h>
template <typename S, typename T>
__device__ void cast_(
const size_t numel,
const size_t num_dims,
const size_t *info,
const S *inp,
T *out
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = inp[i];
}
}
else {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
out[i] = inp[strided_i];
}
}
}
template <typename S, typename T, typename I>
__device__ void cast_through(
const size_t numel,
const size_t num_dims,
const size_t *info,
const S *inp,
T *out
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = static_cast<T>(static_cast<I>(inp[i]));
}
}
else {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
out[i] = static_cast<T>(static_cast<I>(inp[strided_i]));
}
}
}
#define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
@ -9,22 +56,10 @@ extern "C" __global__ void FN_NAME( \
const SRC_TYPENAME *inp, \
DST_TYPENAME *out \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
out[i] = inp[i]; \
} \
} \
else { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
out[i] = inp[strided_i]; \
} \
} \
cast_<SRC_TYPENAME, DST_TYPENAME>(numel, num_dims, info, inp, out); \
} \
#define CAST_BF_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
#define CAST_THROUGH_OP(SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
@ -32,25 +67,12 @@ extern "C" __global__ void FN_NAME( \
const SRC_TYPENAME *inp, \
DST_TYPENAME *out \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
out[i] = (DST_TYPENAME) (float) inp[i]; \
} \
} \
else { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
out[i] = (DST_TYPENAME) (float) inp[strided_i]; \
} \
} \
cast_through<SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME>(numel, num_dims, info, inp, out); \
} \
#if __CUDA_ARCH__ >= 800
CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
// CAST_OP(__nv_bfloat16, uint8_t, cast_bf16_u8)
CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
@ -58,14 +80,15 @@ CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16)
CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
CAST_BF_OP(__nv_bfloat16, __half, cast_bf16_f16)
CAST_BF_OP(__half, __nv_bfloat16, cast_f16_bf16)
CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16)
#endif
#if __CUDA_ARCH__ >= 530
CAST_OP(__half, __half, cast_f16_f16)
// CAST_OP(__half, uint8_t, cast_f16_u8 )
CAST_THROUGH_OP(__half, uint8_t, float, cast_f16_u8)
CAST_OP(__half, uint32_t, cast_f16_u32)
CAST_OP(__half, float, cast_f16_f32)
CAST_OP(__half, double, cast_f16_f64)