diff --git a/kernels/README.md b/kernels/README.md
new file mode 100644
index 00000000..1043f31f
--- /dev/null
+++ b/kernels/README.md
@@ -0,0 +1,4 @@
+# candle-kernels
+
+This crate contains CUDA kernels used from candle. Some of these implementations
+come from the [dfdx crate](https://github.com/coreylowman/dfdx).
diff --git a/kernels/src/binary_add.cu b/kernels/src/binary_add.cu
new file mode 100644
index 00000000..16156bc2
--- /dev/null
+++ b/kernels/src/binary_add.cu
@@ -0,0 +1,20 @@
+#include "binary_op_macros.cuh"
+
+struct BinaryAddOp {};
+
+#if __CUDA_ARCH__ >= 530
+BINARY_OP(__half, badd_fwd_f16, badd_bwd_lhs_f16, badd_bwd_rhs_f16, BinaryAddOp,
+    x + y,
+    1.0,
+    1.0)
+#endif
+
+BINARY_OP(float, badd_fwd_f32, badd_bwd_lhs_f32, badd_bwd_rhs_f32, BinaryAddOp,
+    x + y,
+    1.0,
+    1.0)
+
+BINARY_OP(double, badd_fwd_f64, badd_bwd_lhs_f64, badd_bwd_rhs_f64, BinaryAddOp,
+    x + y,
+    1.0,
+    1.0)
diff --git a/kernels/src/binary_div.cu b/kernels/src/binary_div.cu
new file mode 100644
index 00000000..a36c58b8
--- /dev/null
+++ b/kernels/src/binary_div.cu
@@ -0,0 +1,21 @@
+#include "binary_op_macros.cuh"
+
+struct BinaryDivOp {};
+
+#if __CUDA_ARCH__ >= 530
+BINARY_OP(__half, bdiv_fwd_f16, bdiv_bwd_lhs_f16, bdiv_bwd_rhs_f16, BinaryDivOp,
+    x / y,
+    recipg(y),
+    -x / (y * y))
+#endif
+
+BINARY_OP(float, bdiv_fwd_f32, bdiv_bwd_lhs_f32, bdiv_bwd_rhs_f32, BinaryDivOp,
+    x / y,
+    recipg(y),
+    -x / (y * y))
+
+BINARY_OP(double, bdiv_fwd_f64, bdiv_bwd_lhs_f64, bdiv_bwd_rhs_f64, BinaryDivOp,
+    x / y,
+    recipg(y),
+    -x / (y * y))
+
diff --git a/kernels/src/binary_mul.cu b/kernels/src/binary_mul.cu
new file mode 100644
index 00000000..9a6d1430
--- /dev/null
+++ b/kernels/src/binary_mul.cu
@@ -0,0 +1,21 @@
+#include "binary_op_macros.cuh"
+
+struct BinaryMulKernelOp {};
+
+#if __CUDA_ARCH__ >= 530
+BINARY_OP(__half, bmul_fwd_f16, bmul_bwd_lhs_f16, bmul_bwd_rhs_f16, BinaryMulKernelOp,
+    x * y,
+    y,
+    x)
+#endif
+
+BINARY_OP(float, bmul_fwd_f32, bmul_bwd_lhs_f32, bmul_bwd_rhs_f32, BinaryMulKernelOp,
+    x * y,
+    y,
+    x)
+
+BINARY_OP(double, bmul_fwd_f64, bmul_bwd_lhs_f64, bmul_bwd_rhs_f64, BinaryMulKernelOp,
+    x * y,
+    y,
+    x)
+
diff --git a/kernels/src/binary_op_macros.cuh b/kernels/src/binary_op_macros.cuh
new file mode 100644
index 00000000..b79a4d81
--- /dev/null
+++ b/kernels/src/binary_op_macros.cuh
@@ -0,0 +1,101 @@
+#include "cuda_utils.cuh"
+
+#define LONG_BINARY_OP(TYPENAME, FORWARD, BACKWARD_LHS, BACKWARD_RHS, OP_STRUCT, FUNC, DFDX, DFDY) \
+extern "C" __global__ void FORWARD( \
+    const OP_STRUCT op, \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME *lhs, \
+    const TYPENAME *rhs, \
+    TYPENAME *out \
+) { \
+    const size_t *dims = info; \
+    const size_t *lhs_strides = info + num_dims; \
+    const size_t *rhs_strides = info + 2 * num_dims; \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+        unsigned int tmp_i = i; \
+        unsigned int lhs_i = 0; \
+        unsigned int rhs_i = 0; \
+        for (int d = num_dims - 1; d >= 0; d--) { \
+            unsigned int i_dim = tmp_i % dims[d]; \
+            lhs_i += i_dim * lhs_strides[d]; \
+            rhs_i += i_dim * rhs_strides[d]; \
+            tmp_i /= dims[d]; \
+        } \
+        TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
+        TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
+        TYPENAME fx; \
+        FUNC\
+        out[i] = fx; \
+    } \
+} \
+\
+extern "C" __global__ void BACKWARD_LHS( \
+    const OP_STRUCT op, \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME *lhs, \
+    TYPENAME *grad_lhs, \
+    const size_t chunk_len, \
+    const TYPENAME *rhs, \
+    const TYPENAME *grad_out \
+) { \
+    const size_t *dims = info + 0 * num_dims; \
+    const size_t *out_strides = info + 1 * num_dims; \
+    const size_t *rhs_strides = info + 2 * num_dims; \
+    TYPENAME zero = 0.0; \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+        unsigned int tmp_i = i; \
+        unsigned int out_i = 0; \
+        unsigned int rhs_i = 0; \
+        for (int d = num_dims - 1; d >= 0; d--) { \
+            unsigned int i_dim = tmp_i % dims[d]; \
+            out_i += i_dim * out_strides[d]; \
+            rhs_i += i_dim * rhs_strides[d]; \
+            tmp_i /= dims[d]; \
+        } \
+        TYPENAME x = lhs ? lhs[i / chunk_len] : zero; \
+        TYPENAME y = rhs ? rhs[rhs_i] : zero; \
+        TYPENAME go = grad_out[out_i]; \
+        TYPENAME dfdx = (DFDX); \
+        chunk_sum(chunk_len, dfdx * go, grad_lhs); \
+    } \
+} \
+\
+extern "C" __global__ void BACKWARD_RHS( \
+    const OP_STRUCT op, \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME *lhs, \
+    const TYPENAME *rhs, \
+    TYPENAME *grad_rhs, \
+    const size_t chunk_len, \
+    const TYPENAME *grad_out \
+) { \
+    const size_t *dims = info + 3 * num_dims; \
+    const size_t *out_strides = info + 4 * num_dims; \
+    const size_t *lhs_strides = info + 5 * num_dims; \
+    TYPENAME zero = 0.0; \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+        unsigned int tmp_i = i; \
+        unsigned int lhs_i = 0; \
+        unsigned int out_i = 0; \
+        for (int d = num_dims - 1; d >= 0; d--) { \
+            unsigned int i_dim = tmp_i % dims[d]; \
+            lhs_i += i_dim * lhs_strides[d]; \
+            out_i += i_dim * out_strides[d]; \
+            tmp_i /= dims[d]; \
+        } \
+        TYPENAME x = lhs ? lhs[lhs_i] : zero; \
+        TYPENAME y = rhs ? rhs[i / chunk_len] : zero; \
+        TYPENAME go = grad_out[out_i]; \
+        TYPENAME dfdy = (DFDY); \
+        chunk_sum(chunk_len, dfdy * go, grad_rhs); \
+    } \
+}
+
+#define BINARY_OP(TYPENAME, FORWARD, BACKWARD_LHS, BACKWARD_RHS, OP_STRUCT, FUNC, DFDX, DFDY) \
+    LONG_BINARY_OP(TYPENAME, FORWARD, BACKWARD_LHS, BACKWARD_RHS, OP_STRUCT, fx = (FUNC);, DFDX, DFDY)
diff --git a/kernels/src/binary_sub.cu b/kernels/src/binary_sub.cu
new file mode 100644
index 00000000..f4c93972
--- /dev/null
+++ b/kernels/src/binary_sub.cu
@@ -0,0 +1,21 @@
+#include "binary_op_macros.cuh"
+
+struct BinarySubKernelOp {};
+
+#if __CUDA_ARCH__ >= 530
+BINARY_OP(__half, bsub_fwd_f16, bsub_bwd_lhs_f16, bsub_bwd_rhs_f16, BinarySubKernelOp,
+    x - y,
+    1.0,
+    -1.0)
+#endif
+
+BINARY_OP(float, bsub_fwd_f32, bsub_bwd_lhs_f32, bsub_bwd_rhs_f32, BinarySubKernelOp,
+    x - y,
+    1.0,
+    -1.0)
+
+BINARY_OP(double, bsub_fwd_f64, bsub_bwd_lhs_f64, bsub_bwd_rhs_f64, BinarySubKernelOp,
+    x - y,
+    1.0,
+    -1.0)
+
diff --git a/kernels/src/compatibility.cuh b/kernels/src/compatibility.cuh
new file mode 100644
index 00000000..69e34764
--- /dev/null
+++ b/kernels/src/compatibility.cuh
@@ -0,0 +1,171 @@
+#include "cuda_fp16.h"
+
+// Table showing which features are supported on which compute capability
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
+
+// FIXME: the minimum compute capabilities are just guesses since the table is not specific enough
+
+// #if __CUDA_ARCH__ < 600
+// __device__ __forceinline__ __half __hmax(__half a, __half b) {
+//     return __float2half(fmaxf(__half2float(a), __half2float(b)));
+// }
+// __device__ __forceinline__ __half __hmin(__half a, __half b) {
+//     return __float2half(fminf(__half2float(a), __half2float(b)));
+// }
+// #endif
+
+#if __CUDA_ARCH__ < 800
+__device__ __forceinline__ __half __hmax_nan(__half a, __half b) {
+    // return __hisnan(a) ? a : (__hisnan(b) ? b : __hmax(a, b));
+}
+__device__ __forceinline__ __half __hmin_nan(__half a, __half b) {
+    // return __hisnan(a) ? a : (__hisnan(b) ? b : __hmin(a, b));
+}
+#endif
+
+#if __CUDA_ARCH__ < 600
+// Copied from https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions
+__device__ double atomicAdd(double* address, double val) {
+    unsigned long long int* address_as_ull = (unsigned long long int*)address;
+    unsigned long long int old = *address_as_ull, assumed;
+
+    do {
+        assumed = old;
+        old = atomicCAS(address_as_ull, assumed,
+            __double_as_longlong(val +
+                __longlong_as_double(assumed)));
+
+        // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+    } while (assumed != old);
+
+    return __longlong_as_double(old);
+}
+#endif
+
+
+#if __CUDA_ARCH__ < 700
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomicadd
+// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
+// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
+__device__ __half atomicAdd(__half *address, __half val) {
+    // unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+    // unsigned int old = *address_as_ui;
+    // unsigned int assumed;
+    // bool unaligned = (size_t) address & 2;
+    // do {
+    //     assumed = old;
+    //     unsigned int hsum;
+    //     hsum = unaligned ? (old >> 16) : (old & 0xffff);
+    //     hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
+    //     old = atomicCAS(address_as_ui, assumed,
+    //         unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
+    //     );
+
+    // } while (assumed != old);
+    // return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
+}
+#endif
+
+
+__device__ __forceinline__ __half atomicMaxf(__half* address, __half val) {
+#if __CUDA_ARCH__ < 700
+    // On older GPUs we do not have access to atomicCAS for shorts, so we have to do some trickery.
+    // Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
+    unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+    bool unaligned = (size_t) address & 2;
+    do {
+        assumed = old;
+        unsigned int hmax;
+        hmax = unaligned ? (old >> 16) : (old & 0xffff);
+        hmax = __half_as_ushort(__hmax_nan(val, __ushort_as_half(hmax)));
+        old = atomicCAS(address_as_ui, assumed,
+            unaligned ? (old & 0xffff) | (hmax << 16) : (old & 0xffff0000) | hmax
+        );

+    } while (assumed != old);
+    return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
+#else
+    // Based on https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions
+    unsigned short int* casted_address = (unsigned short int*)address;
+    unsigned short int old = *casted_address;
+    unsigned short int assumed;
+    do {
+        assumed = old;
+        old = atomicCAS(casted_address, assumed, __half_as_ushort(__hmax_nan(val, __ushort_as_half(assumed))));
+        // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+    } while (assumed != old);
+    return __ushort_as_half(old);
+#endif
+}
+
+// atomicMax is not implemented for floats,
+// solution copied https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
+__device__ __forceinline__ float atomicMaxf(float * addr, float value) {
+    if (signbit(value)) {
+        return __uint_as_float(atomicMin((unsigned int *)addr, __float_as_uint(value)));
+    } else {
+        return __int_as_float(atomicMax((int *)addr, __float_as_int(value)));
+    }
+}
+
+__device__ __forceinline__ double atomicMaxf(double * addr, double value) {
+    if (signbit(value)) {
+        return __longlong_as_double(atomicMin((unsigned long long int *)addr, __double_as_longlong(value)));
+    } else {
+        return __longlong_as_double(atomicMax((long long int *)addr, __double_as_longlong(value)));
+    }
+}
+
+
+__device__ __forceinline__ __half atomicMinf(__half* address, __half val) {
+#if __CUDA_ARCH__ < 700
+    // On older GPUs we do not have access to atomicCAS for shorts, so we have to do some trickery.
+    // Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
+    unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+    bool unaligned = (size_t) address & 2;
+    do {
+        assumed = old;
+        unsigned int hmin;
+        hmin = unaligned ? (old >> 16) : (old & 0xffff);
+        hmin = __half_as_ushort(__hmin_nan(val, __ushort_as_half(hmin)));
+        old = atomicCAS(address_as_ui, assumed,
+            unaligned ? (old & 0xffff) | (hmin << 16) : (old & 0xffff0000) | hmin
+        );

+    } while (assumed != old);
+    return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
+#else
+    // Based on https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions
+    unsigned short int* casted_address = (unsigned short int*)address;
+    unsigned short int old = *casted_address;
+    unsigned short int assumed;
+    do {
+        assumed = old;
+        old = atomicCAS(casted_address, assumed, __half_as_ushort(__hmin_nan(val, __ushort_as_half(assumed))));
+        // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+    } while (assumed != old);
+    return __ushort_as_half(old);
+#endif
+}
+
+// atomicMin is not implemented for floats,
+// solution copied https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
+__device__ __forceinline__ float atomicMinf(float * addr, float value) {
+    if (signbit(value)) {
+        return __uint_as_float(atomicMax((unsigned int *)addr, __float_as_uint(value)));
+    } else {
+        return __int_as_float(atomicMin((int *)addr, __float_as_int(value)));
+    }
+}
+
+__device__ __forceinline__ double atomicMinf(double * addr, double value) {
+    if (signbit(value)) {
+        return __longlong_as_double(atomicMax((unsigned long long int *)addr, __double_as_longlong(value)));
+    } else {
+        return __longlong_as_double(atomicMin((long long int *)addr, __double_as_longlong(value)));
+    }
+}
diff --git a/kernels/src/cuda_utils.cuh b/kernels/src/cuda_utils.cuh
new file mode 100644
index 00000000..580ed11a
--- /dev/null
+++ b/kernels/src/cuda_utils.cuh
@@ -0,0 +1,138 @@
+#include "cuda_fp16.h"
+#include "compatibility.cuh"
+
+__device__ unsigned int get_strided_index(
+    unsigned int idx,
+    const size_t num_dims,
+    const size_t *dims,
+    const size_t *strides
+) {
+    unsigned int strided_i = 0;
+    for (unsigned int d = 0; d < num_dims; d++) {
+        unsigned int dim_idx = num_dims - 1 - d;
+        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+        idx /= dims[dim_idx];
+    }
+    return strided_i;
+}
+
+__device__ unsigned int restrided(
+    const unsigned int strided_i,
+    const size_t num_dims,
+    const size_t *dims,
+    const size_t *strides,
+    const size_t *new_strides
+) {
+    unsigned int idx = 0;
+    for (int d = 0; d < num_dims; d++) {
+        idx += (strides[d] == 0 ? 0 : (strided_i / strides[d]) % dims[d]) * new_strides[d];
+    }
+    return idx;
+}
+
+// Sourced from https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
+// Input must be less than or equal to 2 ^ 16
+// used in reductions
+__device__ __forceinline__ unsigned int next_power_of_two(unsigned int v) {
+    v--;
+    v |= v >> 1;
+    v |= v >> 2;
+    v |= v >> 4;
+    v |= v >> 8;
+    v++;
+    return v;
+}
+
+// Efficiently computes the sum of each chunk in "data" of size chunk_len, and
+// stores the sums in out[i / chunk_len]
+template<typename T>
+__device__ void chunk_sum(
+    const size_t chunk_len,
+    const T data,
+    T* out
+) {
+    __shared__ T buf[1024];
+
+    // assumes that threads where i >= numel have already exited
+    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
+    unsigned int block_i = threadIdx.x;
+
+    // Fall back to atomicAdd if chunk_len is small to reduce overhead
+    if (chunk_len <= 2) {
+        atomicAdd(out + i / chunk_len, data);
+        return;
+    }
+    buf[block_i] = data;
+
+    unsigned int chunk_i = i % chunk_len;
+    unsigned int chunk_start = max((int)(block_i - chunk_i), 0);
+    unsigned int chunk_end = min((unsigned int)(block_i + chunk_len - chunk_i), blockDim.x);
+
+    chunk_i = block_i - chunk_start;
+
+    size_t max_chunk_len = min(chunk_end - chunk_start, blockDim.x);
+    size_t incr = next_power_of_two(max_chunk_len) >> 1;
+
+    __syncthreads();
+
+    // Uses sequential addressing as discussed in
+    // https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf
+    for (; incr > 0; incr >>= 1) {
+        unsigned int block_i_2 = block_i + incr;
+
+        if (block_i_2 < chunk_end && chunk_i < incr) {
+            // This is sound because __syncthreads and the conditions above
+            // ensure that no data races occur
+            buf[block_i] += buf[block_i_2];
+        }
+
+        __syncthreads();
+    }
+
+    if (block_i == chunk_start) {
+        atomicAdd(out + i / chunk_len, buf[block_i]);
+    }
+}
+
+__device__ __forceinline__ bool isnang(float a) { return isnan(a); }
+__device__ __forceinline__ bool isnang(double a) { return isnan(a); }
+__device__ __forceinline__ float recipg(float a) { return 1.0 / a; }
+__device__ __forceinline__ double recipg(double a) { return 1.0 / a; }
+__device__ __forceinline__ float cosg(float a) { return cosf(a); }
+__device__ __forceinline__ double cosg(double a) { return cos(a); }
+__device__ __forceinline__ float sing(float a) { return sinf(a); }
+__device__ __forceinline__ double sing(double a) { return sin(a); }
+__device__ __forceinline__ float sqrtg(float a) { return sqrtf(a); }
+__device__ __forceinline__ double sqrtg(double a) { return sqrt(a); }
+__device__ __forceinline__ float powg(float a, float b) { return powf(a, b); }
+__device__ __forceinline__ double powg(double a, double b) { return pow(a, b); }
+__device__ __forceinline__ float tanhg(float a) { return tanhf(a); }
+__device__ __forceinline__ double tanhg(double a) { return tanh(a); }
+__device__ __forceinline__ float maxg(float a, float b) { return fmaxf(a, b); }
+__device__ __forceinline__ double maxg(double a, double b) { return fmax(a, b); }
+__device__ __forceinline__ float ming(float a, float b) { return fminf(a, b); }
+__device__ __forceinline__ double ming(double a, double b) { return fmin(a, b); }
+__device__ __forceinline__ float logg(float a) { return logf(a); }
+__device__ __forceinline__ double logg(double a) { return log(a); }
+__device__ __forceinline__ float expg(float a) { return expf(a); }
+__device__ __forceinline__ double expg(double a) { return exp(a); }
+__device__ __forceinline__ float absg(float a) { return fabsf(a); }
+__device__ __forceinline__ double absg(double a) { return fabs(a); }
+__device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); }
+__device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); }
+
+#if __CUDA_ARCH__ >= 530
+__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
+__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
+__device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
+__device__ __forceinline__ __half cosg(__half a) { return hcos(a); }
+__device__ __forceinline__ __half sing(__half a) { return hsin(a); }
+__device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return one / a; }
+__device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); }
+__device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); }
+__device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); }
+__device__ __forceinline__ __half logg(__half a) { return hlog(a); }
+__device__ __forceinline__ __half expg(__half a) { return hexp(a); }
+__device__ __forceinline__ __half absg(__half a) { return __habs(a); }
+__device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); }
+#endif
diff --git a/kernels/src/lib.rs b/kernels/src/lib.rs
index bb599ac4..5ebea218 100644
--- a/kernels/src/lib.rs
+++ b/kernels/src/lib.rs
@@ -1,2 +1,6 @@
 pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
+pub const BINARY_ADD: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_add.ptx"));
+pub const BINARY_DIV: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_div.ptx"));
+pub const BINARY_MUL: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_mul.ptx"));
+pub const BINARY_SUB: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_sub.ptx"));
 pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
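
Note for reviewers (illustration only, not an additional file in this patch): BINARY_OP is a thin wrapper that forwards to LONG_BINARY_OP with `fx = (FUNC);` substituted, so every invocation in the .cu files above generates three extern "C" kernels. For example, BINARY_OP(float, badd_fwd_f32, ...) in binary_add.cu produces roughly the forward kernel below; the two backward kernels follow the same pattern.

```cuda
// Approximate expansion of BINARY_OP(float, badd_fwd_f32, ...), forward kernel only.
// `info` packs the output dims followed by the lhs and rhs strides (each num_dims
// entries), so broadcasting and non-contiguous inputs are handled entirely by the
// per-element index computation below.
extern "C" __global__ void badd_fwd_f32(
    const BinaryAddOp op,
    const size_t numel,
    const size_t num_dims,
    const size_t *info,
    const float *lhs,
    const float *rhs,
    float *out
) {
    const size_t *dims = info;
    const size_t *lhs_strides = info + num_dims;
    const size_t *rhs_strides = info + 2 * num_dims;
    // Grid-stride loop over the numel output elements.
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        unsigned int tmp_i = i;
        unsigned int lhs_i = 0;
        unsigned int rhs_i = 0;
        // Decompose the linear output index into per-dimension coordinates and
        // re-linearize with each operand's strides.
        for (int d = num_dims - 1; d >= 0; d--) {
            unsigned int i_dim = tmp_i % dims[d];
            lhs_i += i_dim * lhs_strides[d];
            rhs_i += i_dim * rhs_strides[d];
            tmp_i /= dims[d];
        }
        // A null lhs/rhs pointer means "read the operand from out", i.e. in place.
        float x = lhs ? lhs[lhs_i] : out[i];
        float y = rhs ? rhs[rhs_i] : out[i];
        float fx;
        fx = (x + y);  // FUNC as substituted by BINARY_OP
        out[i] = fx;
    }
}
```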
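The host side that packs `info` and launches these kernels is not part of this patch; candle compiles the .cu files to PTX in build.rs and exposes them through the constants added to lib.rs. Purely as a sketch of the calling convention, and assuming instead that binary_add.cu is compiled directly with nvcc, a standalone harness could look like the following. The file name, shapes, and the whole `main` function are illustrative assumptions, not part of how candle drives these kernels.

```cuda
// Hypothetical standalone harness: launch badd_fwd_f32 through the CUDA runtime API.
#include <cstdio>
#include "binary_add.cu"

int main() {
    const size_t num_dims = 2;
    const size_t numel = 2 * 3;              // output shape [2, 3]
    float h_lhs[6] = {1, 2, 3, 4, 5, 6};     // shape [2, 3], contiguous
    float h_rhs[3] = {10, 20, 30};           // shape [1, 3], broadcast over dim 0
    // info = dims, then lhs strides, then rhs strides; the zero stride on the
    // first rhs dimension implements the broadcast.
    size_t h_info[6] = {2, 3, 3, 1, 0, 1};

    float *d_lhs, *d_rhs, *d_out;
    size_t *d_info;
    cudaMalloc(&d_lhs, sizeof(h_lhs));
    cudaMalloc(&d_rhs, sizeof(h_rhs));
    cudaMalloc(&d_out, numel * sizeof(float));
    cudaMalloc(&d_info, sizeof(h_info));
    cudaMemcpy(d_lhs, h_lhs, sizeof(h_lhs), cudaMemcpyHostToDevice);
    cudaMemcpy(d_rhs, h_rhs, sizeof(h_rhs), cudaMemcpyHostToDevice);
    cudaMemcpy(d_info, h_info, sizeof(h_info), cudaMemcpyHostToDevice);

    // The op struct is empty; numel fits in a single block here.
    badd_fwd_f32<<<1, 256>>>(BinaryAddOp{}, numel, num_dims, d_info, d_lhs, d_rhs, d_out);

    float h_out[6];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    for (int i = 0; i < 6; i++) printf("%g ", h_out[i]);  // expected: 11 22 33 14 25 36
    printf("\n");
    return 0;
}
```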
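For the backward kernels, `info` holds the output dims, the output strides, and the other operand's strides (offsets 0, 1, 2 times num_dims for BACKWARD_LHS and 3, 4, 5 for BACKWARD_RHS), and the `lhs[i / chunk_len]` indexing means each lhs element owns chunk_len consecutive output positions; chunk_sum then reduces DFDX * grad_out over each chunk with shared memory plus atomicAdd. The serial CPU sketch below makes that contract explicit for bdiv_bwd_lhs_f32, where DFDX is recipg(y); the function name and signature are illustrative only.

```c
#include <stddef.h>

// Serial reference (sketch, not part of the patch) of what bdiv_bwd_lhs_f32
// accumulates into grad_lhs. grad_lhs must be zero-initialized and has
// numel / chunk_len elements.
void bdiv_bwd_lhs_reference(
    size_t numel, size_t num_dims, const size_t *info, size_t chunk_len,
    const float *rhs, const float *grad_out, float *grad_lhs
) {
    // Same info layout as BACKWARD_LHS: dims, then output strides, then rhs strides.
    const size_t *dims = info;
    const size_t *out_strides = info + num_dims;
    const size_t *rhs_strides = info + 2 * num_dims;
    for (size_t i = 0; i < numel; i++) {
        size_t tmp_i = i, out_i = 0, rhs_i = 0;
        for (int d = (int)num_dims - 1; d >= 0; d--) {
            size_t i_dim = tmp_i % dims[d];
            out_i += i_dim * out_strides[d];
            rhs_i += i_dim * rhs_strides[d];
            tmp_i /= dims[d];
        }
        float y = rhs[rhs_i];
        float dfdx = 1.0f / y;  // DFDX for div: recipg(y)
        // Each lhs element owns chunk_len consecutive output positions;
        // chunk_sum performs this accumulation on the GPU.
        grad_lhs[i / chunk_len] += dfdx * grad_out[out_i];
    }
}
```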