From cd2a171c0614b5ab1daf2f01f6f5532c34617076 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 13:25:02 +0100 Subject: [PATCH] Add the where kernels. --- kernels/src/cuda_utils.cuh | 4 ++++ kernels/src/lib.rs | 1 + kernels/src/ternary.cu | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+) create mode 100644 kernels/src/ternary.cu diff --git a/kernels/src/cuda_utils.cuh b/kernels/src/cuda_utils.cuh index a7c2a9f6..c11e8e22 100644 --- a/kernels/src/cuda_utils.cuh +++ b/kernels/src/cuda_utils.cuh @@ -1,6 +1,10 @@ #include "cuda_fp16.h" #include "compatibility.cuh" +// TODO: This is often used to check that the data is contiguous so that +// kernels can be easily mapped. However this only returns true for row +// major, if all the inputs are column major, we could apply the fast path +// too (but we wouldn't if some of them are row major and some column major). __device__ bool is_contiguous( const size_t num_dims, const size_t *dims, diff --git a/kernels/src/lib.rs b/kernels/src/lib.rs index f199215c..d29022da 100644 --- a/kernels/src/lib.rs +++ b/kernels/src/lib.rs @@ -3,4 +3,5 @@ pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx")); pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx")); pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx")); pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx")); +pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx")); pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx")); diff --git a/kernels/src/ternary.cu b/kernels/src/ternary.cu new file mode 100644 index 00000000..8f8b3ac5 --- /dev/null +++ b/kernels/src/ternary.cu @@ -0,0 +1,37 @@ +#include "cuda_utils.cuh" +#include + +#define WHERE_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const uint32_t *ids, \ + const TYPENAME *t, \ + const TYPENAME *f, \ + TYPENAME *out \ +) { \ + const size_t *dims = info; \ + const size_t *strides = info + num_dims; \ + const size_t *strides_t = info + 2*num_dims; \ + const size_t *strides_f = info + 2*num_dims; \ + if (is_contiguous(num_dims, dims, strides) \ + && is_contiguous(num_dims, dims, strides_f) \ + && is_contiguous(num_dims, dims, strides_t)) { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + out[i] = ids[i] ? t[i] : f[i]; \ + } \ + } \ + else { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ + unsigned strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ + unsigned strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ + out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \ + } \ + } \ +} \ + +WHERE_OP(float, where_f32) +WHERE_OP(double, where_f64) +WHERE_OP(uint32_t, where_u32)