Mirror of https://github.com/huggingface/candle.git
Add the bf16 cuda kernels.
@@ -28,6 +28,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
AFFINE_OP(__nv_bfloat16, affine_bf16)
#endif

#if __CUDA_ARCH__ >= 530
AFFINE_OP(__half, affine_f16)
#endif
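For context only (not part of the diff): a minimal sketch of the kind of kernel an AFFINE_OP instantiation such as affine_bf16 produces, assuming a contiguous elementwise y = x * mul + add. The kernel name and parameter list below are hypothetical, not the macro's real expansion.

// Hypothetical sketch, not the actual AFFINE_OP expansion from this commit.
#include "cuda_bf16.h"
#include <cstddef>

#if __CUDA_ARCH__ >= 800
extern "C" __global__ void affine_bf16_sketch(
    const size_t numel,
    const __nv_bfloat16 *x,
    __nv_bfloat16 *y,
    const __nv_bfloat16 mul,
    const __nv_bfloat16 add
) {
    // Grid-stride loop: each thread handles every (gridDim.x * blockDim.x)-th element.
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += blockDim.x * gridDim.x) {
        // __nv_bfloat16 arithmetic operators require sm_80+, hence the
        // __CUDA_ARCH__ >= 800 guard mirrored from the diff.
        y[i] = x[i] * mul + add;
    }
}
#endif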
@@ -1,6 +1,13 @@
#include "binary_op_macros.cuh"
#include<stdint.h>

#if __CUDA_ARCH__ >= 800
BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
BINARY_OP(__nv_bfloat16, bdiv_bf16, x / y)
BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
BINARY_OP(__nv_bfloat16, bsub_bf16, x - y)
#endif

#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, badd_f16, x + y)
BINARY_OP(__half, bdiv_f16, x / y)
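For context only (not part of the diff): a simplified sketch of how a BINARY_OP-style macro can splice an expression such as x + y into a generated kernel. BINARY_OP_SKETCH and badd_bf16_sketch are hypothetical names; the real macro in binary_op_macros.cuh is more involved and only its instantiations appear in this hunk.

// Hypothetical simplified macro, not the one defined in binary_op_macros.cuh.
#include "cuda_bf16.h"
#include <cstddef>

#define BINARY_OP_SKETCH(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const TYPENAME *lhs, \
    const TYPENAME *rhs, \
    TYPENAME *out \
) { \
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \
         i += blockDim.x * gridDim.x) { \
        TYPENAME x = lhs[i]; \
        TYPENAME y = rhs[i]; \
        /* FUNC is an arbitrary expression over x and y, e.g. x + y. */ \
        out[i] = FUNC; \
    } \
}

#if __CUDA_ARCH__ >= 800
BINARY_OP_SKETCH(__nv_bfloat16, badd_bf16_sketch, x + y)
#endif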
@@ -24,6 +24,19 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)

CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
CAST_OP(__nv_bfloat16, __half, cast_bf16_f16)
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
CAST_OP(__half, __nv_bfloat16, cast_f16_bf16)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
#endif

#if __CUDA_ARCH__ >= 530
CAST_OP(__half, __half, cast_f16_f16)
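For context only (not part of the diff): a sketch of the contiguous bf16-to-f32 conversion a cast_bf16_f32-style kernel performs. The name and signature below are illustrative assumptions.

// Hypothetical sketch, not the actual CAST_OP expansion from this commit.
#include "cuda_bf16.h"
#include <cstddef>

#if __CUDA_ARCH__ >= 800
extern "C" __global__ void cast_bf16_f32_sketch(
    const size_t numel,
    const __nv_bfloat16 *inp,
    float *out
) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += blockDim.x * gridDim.x) {
        // Widening bf16 -> f32 is exact: bf16 is the top 16 bits of an f32,
        // so __bfloat162float just re-expands the value.
        out[i] = __bfloat162float(inp[i]);
    }
}
#endif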
@@ -1,4 +1,5 @@
#include "cuda_fp16.h"
#include "cuda_bf16.h"

// Table showing which features are supported on which compute capability
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
@@ -1,4 +1,3 @@
#include "cuda_fp16.h"
#include "compatibility.cuh"

// TODO: This is often used to check that the data is contiguous so that
@@ -156,3 +155,19 @@ __device__ __forceinline__ __half expg(__half a) { return hexp(a); }
__device__ __forceinline__ __half absg(__half a) { return __habs(a); }
__device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); }
#endif

#if __CUDA_ARCH__ >= 800
__device__ __forceinline__ __nv_bfloat16 powg(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b))); }
__device__ __forceinline__ bool isnang(__nv_bfloat16 a) { return __hisnan(a); }
__device__ __forceinline__ __nv_bfloat16 sqrtg(__nv_bfloat16 a) { return hsqrt(a); }
__device__ __forceinline__ __nv_bfloat16 cosg(__nv_bfloat16 a) { return hcos(a); }
__device__ __forceinline__ __nv_bfloat16 sing(__nv_bfloat16 a) { return hsin(a); }
__device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16 one = 1.0; return one / a; }
__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); }
__device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); }
__device__ __forceinline__ __nv_bfloat16 absg(__nv_bfloat16 a) { return __habs(a); }
__device__ __forceinline__ __nv_bfloat16 copysigng(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(copysignf(__bfloat162float(a), __bfloat162float(b))); }
#endif
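For context only (not part of the diff): the point of these __nv_bfloat16 overloads is that macro-generated kernels can call one helper name (expg, logg, absg, ...) for every element type. A hypothetical usage sketch, assuming the overloads above and cuda_bf16.h are in scope in the same translation unit:

// Hypothetical kernel, not part of this commit; assumes the *g overloads above
// (and cuda_bf16.h) are already included in this translation unit.
#if __CUDA_ARCH__ >= 800
extern "C" __global__ void exp_inplace_bf16_sketch(const size_t numel, __nv_bfloat16 *buf) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += blockDim.x * gridDim.x) {
        // Overload resolution picks the __nv_bfloat16 expg added in this hunk;
        // the same source line works unchanged for float or __half buffers.
        buf[i] = expg(buf[i]);
    }
}
#endif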
@@ -29,6 +29,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
EMB_OP(__nv_bfloat16, emb_bf16)
#endif

#if __CUDA_ARCH__ >= 530
EMB_OP(__half, emb_f16)
#endif
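For context only (not part of the diff): an embedding lookup is a row gather rather than an elementwise op. Below is a sketch of what an emb_bf16-style kernel might do for contiguous buffers; the name, parameters, and (vocab, hidden) row-major layout are assumptions, not the macro's real expansion.

// Hypothetical sketch, not the actual EMB_OP expansion from this commit.
#include "cuda_bf16.h"
#include <cstdint>
#include <cstddef>

#if __CUDA_ARCH__ >= 800
extern "C" __global__ void emb_bf16_sketch(
    const size_t num_ids,
    const size_t hidden,
    const uint32_t *ids,          // token ids, length num_ids
    const __nv_bfloat16 *table,   // (vocab, hidden) row-major embedding table
    __nv_bfloat16 *out            // (num_ids, hidden) gathered rows
) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
         i < num_ids * hidden; i += blockDim.x * gridDim.x) {
        const size_t row = i / hidden; // which output row / token
        const size_t col = i % hidden; // position inside the embedding vector
        out[i] = table[ids[row] * hidden + col];
    }
}
#endif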
@@ -43,6 +43,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
SUM_OP(__nv_bfloat16, sum_bf16)
#endif

#if __CUDA_ARCH__ >= 530
SUM_OP(__half, sum_f16)
#endif
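For context only (not part of the diff): reductions are where bf16's 7-bit mantissa loses the most precision, so a common pattern is to accumulate in f32 and convert back once per output. A much-simplified row-sum sketch with hypothetical names follows; the SUM_OP macro instantiated above is more general.

// Hypothetical sketch, far simpler than the SUM_OP macro instantiated above.
#include "cuda_bf16.h"
#include <cstddef>

#if __CUDA_ARCH__ >= 800
extern "C" __global__ void sum_rows_bf16_sketch(
    const size_t rows,
    const size_t cols,
    const __nv_bfloat16 *inp,  // (rows, cols), contiguous
    __nv_bfloat16 *out         // (rows,)
) {
    // One thread per output row; accumulate in float to limit rounding error.
    for (unsigned int r = blockIdx.x * blockDim.x + threadIdx.x; r < rows;
         r += blockDim.x * gridDim.x) {
        float acc = 0.0f;
        for (size_t c = 0; c < cols; ++c) {
            acc += __bfloat162float(inp[r * cols + c]);
        }
        out[r] = __float2bfloat16(acc);
    }
}
#endif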
@@ -32,6 +32,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
WHERE_OP(__nv_bfloat16, where_bf16)
#endif

#if __CUDA_ARCH__ >= 530
WHERE_OP(__half, where_f16)
#endif
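For context only (not part of the diff): a sketch of the elementwise select a where_bf16-style kernel performs over contiguous buffers, taking from t where the condition id is non-zero and from f otherwise. The names, the uint32_t mask type, and the signature are assumptions.

// Hypothetical sketch, not the actual WHERE_OP expansion from this commit.
#include "cuda_bf16.h"
#include <cstdint>
#include <cstddef>

#if __CUDA_ARCH__ >= 800
extern "C" __global__ void where_bf16_sketch(
    const size_t numel,
    const uint32_t *ids,       // condition mask, non-zero selects `t`
    const __nv_bfloat16 *t,    // values used where the condition holds
    const __nv_bfloat16 *f,    // values used otherwise
    __nv_bfloat16 *out
) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += blockDim.x * gridDim.x) {
        out[i] = ids[i] ? t[i] : f[i];
    }
}
#endif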
@@ -40,6 +40,20 @@ __device__ __forceinline__ T relu_fwd(T x) {
}


#if __CUDA_ARCH__ >= 800
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
UNARY_OP(__nv_bfloat16, uneg_bf16, -x)
UNARY_OP(__nv_bfloat16, uexp_bf16, expg(x))
UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
#endif

#if __CUDA_ARCH__ >= 530
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)