diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu
index 3ce5b8a7..a52dd639 100644
--- a/candle-kernels/src/affine.cu
+++ b/candle-kernels/src/affine.cu
@@ -28,6 +28,10 @@ extern "C" __global__ void FN_NAME( \
     } \
 } \
 
+#if __CUDA_ARCH__ >= 800
+AFFINE_OP(__nv_bfloat16, affine_bf16)
+#endif
+
 #if __CUDA_ARCH__ >= 530
 AFFINE_OP(__half, affine_f16)
 #endif
diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu
index d8758a5e..65f24db1 100644
--- a/candle-kernels/src/binary.cu
+++ b/candle-kernels/src/binary.cu
@@ -1,6 +1,13 @@
 #include "binary_op_macros.cuh"
 #include<stdint.h>
 
+#if __CUDA_ARCH__ >= 800
+BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
+BINARY_OP(__nv_bfloat16, bdiv_bf16, x / y)
+BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
+BINARY_OP(__nv_bfloat16, bsub_bf16, x - y)
+#endif
+
 #if __CUDA_ARCH__ >= 530
 BINARY_OP(__half, badd_f16, x + y)
 BINARY_OP(__half, bdiv_f16, x / y)
diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu
index 3b64de61..817823c0 100644
--- a/candle-kernels/src/cast.cu
+++ b/candle-kernels/src/cast.cu
@@ -24,6 +24,19 @@ extern "C" __global__ void FN_NAME( \
     } \
 } \
 
+#if __CUDA_ARCH__ >= 800
+CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
+
+CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
+CAST_OP(__nv_bfloat16, __half, cast_bf16_f16)
+CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
+CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
+CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
+CAST_OP(__half, __nv_bfloat16, cast_f16_bf16)
+CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
+CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
+#endif
+
 #if __CUDA_ARCH__ >= 530
 CAST_OP(__half, __half, cast_f16_f16)
 
diff --git a/candle-kernels/src/compatibility.cuh b/candle-kernels/src/compatibility.cuh
index 69e34764..2df8e921 100644
--- a/candle-kernels/src/compatibility.cuh
+++ b/candle-kernels/src/compatibility.cuh
@@ -1,4 +1,5 @@
 #include "cuda_fp16.h"
+#include "cuda_bf16.h"
 
 // Table showing which features are supported on which compute capability
 // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index c11e8e22..5d9bddee 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -1,4 +1,3 @@
-#include "cuda_fp16.h"
 #include "compatibility.cuh"
 
 // TODO: This is often used to check that the data is contiguous so that
@@ -156,3 +155,19 @@ __device__ __forceinline__ __half expg(__half a) { return hexp(a); }
 __device__ __forceinline__ __half absg(__half a) { return __habs(a); }
 __device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); }
 #endif
+
+#if __CUDA_ARCH__ >= 800
+__device__ __forceinline__ __nv_bfloat16 powg(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b))); }
+__device__ __forceinline__ bool isnang(__nv_bfloat16 a) { return __hisnan(a); }
+__device__ __forceinline__ __nv_bfloat16 sqrtg(__nv_bfloat16 a) { return hsqrt(a); }
+__device__ __forceinline__ __nv_bfloat16 cosg(__nv_bfloat16 a) { return hcos(a); }
+__device__ __forceinline__ __nv_bfloat16 sing(__nv_bfloat16 a) { return hsin(a); }
+__device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16 one = 1.0; return one / a; }
+__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
+__device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); }
+__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); }
+__device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); }
+__device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); }
+__device__ __forceinline__ __nv_bfloat16 absg(__nv_bfloat16 a) { return __habs(a); }
+__device__ __forceinline__ __nv_bfloat16 copysigng(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(copysignf(__bfloat162float(a), __bfloat162float(b))); }
+#endif
diff --git a/candle-kernels/src/embeddings.cu b/candle-kernels/src/embeddings.cu
index 8f27b0c9..18fe5dfb 100644
--- a/candle-kernels/src/embeddings.cu
+++ b/candle-kernels/src/embeddings.cu
@@ -29,6 +29,10 @@ extern "C" __global__ void FN_NAME( \
     } \
 } \
 
+#if __CUDA_ARCH__ >= 800
+EMB_OP(__nv_bfloat16, emb_bf16)
+#endif
+
 #if __CUDA_ARCH__ >= 530
 EMB_OP(__half, emb_f16)
 #endif
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index e1ed57ab..c341fcfb 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -43,6 +43,10 @@ extern "C" __global__ void FN_NAME( \
     } \
 } \
 
+#if __CUDA_ARCH__ >= 800
+SUM_OP(__nv_bfloat16, sum_bf16)
+#endif
+
 #if __CUDA_ARCH__ >= 530
 SUM_OP(__half, sum_f16)
 #endif
diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu
index 2a20fbec..d08f9e10 100644
--- a/candle-kernels/src/ternary.cu
+++ b/candle-kernels/src/ternary.cu
@@ -32,6 +32,10 @@ extern "C" __global__ void FN_NAME( \
     } \
 } \
 
+#if __CUDA_ARCH__ >= 800
+WHERE_OP(__nv_bfloat16, where_bf16)
+#endif
+
 #if __CUDA_ARCH__ >= 530
 WHERE_OP(__half, where_f16)
 #endif
diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu
index 7af5b388..c4df7893 100644
--- a/candle-kernels/src/unary.cu
+++ b/candle-kernels/src/unary.cu
@@ -40,6 +40,20 @@ __device__ __forceinline__ T relu_fwd(T x) {
 }
 
 
+#if __CUDA_ARCH__ >= 800
+UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
+UNARY_OP(__nv_bfloat16, uneg_bf16, -x)
+UNARY_OP(__nv_bfloat16, uexp_bf16, expg(x))
+UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
+UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
+UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
+UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
+UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
+UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
+UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
+UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
+#endif
+
 #if __CUDA_ARCH__ >= 530
 UNARY_OP(__half, ucopy_f16, x)
 UNARY_OP(__half, uneg_f16, -x)
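
Note (not part of the diff): every hunk above mirrors the existing f16 support, with __CUDA_ARCH__ >= 800 guarding the bf16 instantiations because native __nv_bfloat16 arithmetic requires compute capability 8.0 (Ampere) or newer, just as __half arithmetic requires 5.3. As a rough illustration of the pattern, the stand-alone sketch below shows the shape of kernel these guarded macros produce; the kernel name and body are illustrative only, not the actual AFFINE_OP expansion.

#include <cuda_bf16.h>

#if __CUDA_ARCH__ >= 800
// Grid-stride affine kernel over bf16 data: out[i] = inp[i] * mul + add.
// The bf16 operator overloads (*, +) used here are only available on sm_80+,
// which is why the whole definition sits behind the __CUDA_ARCH__ guard.
extern "C" __global__ void affine_bf16_demo(
    const size_t numel,
    const __nv_bfloat16 *inp,
    __nv_bfloat16 *out,
    const __nv_bfloat16 mul,
    const __nv_bfloat16 add
) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
         i < numel;
         i += blockDim.x * gridDim.x) {
        out[i] = inp[i] * mul + add;
    }
}
#endif

Compiling with nvcc --gpu-architecture=sm_80 defines __CUDA_ARCH__ as 800 during device compilation, so the block is emitted; for an sm_75 target the preprocessor drops it. Each bf16 symbol therefore only exists in PTX built for Ampere and later, and callers are expected to dispatch accordingly.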