From 5276755fb39db71074d0305372bd14be58d524c6 Mon Sep 17 00:00:00 2001
From: laurent
Date: Thu, 22 Jun 2023 15:12:59 +0100
Subject: [PATCH] Add cuda support for unary ops.

---
 examples/cuda_basics.rs   |  1 +
 kernels/src/binary.cu     | 18 +++-------------
 kernels/src/lib.rs        |  1 +
 kernels/src/unary.cu      | 26 +++++++++++++++++++++++
 src/cuda_backend.rs       | 44 +++++++++++++++++++++++++++++++++------
 src/dummy_cuda_backend.rs |  8 +++++++
 src/storage.rs            | 26 +++++++++++++++++++++--
 7 files changed, 101 insertions(+), 23 deletions(-)
 create mode 100644 kernels/src/unary.cu

diff --git a/examples/cuda_basics.rs b/examples/cuda_basics.rs
index 52a45999..3db613f6 100644
--- a/examples/cuda_basics.rs
+++ b/examples/cuda_basics.rs
@@ -8,6 +8,7 @@ fn main() -> Result<()> {
     let y = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?;
     let z = (y + x * 3.)?;
     println!("{:?}", z.to_vec1::<f32>()?);
+    println!("{:?}", z.sqrt()?.to_vec1::<f32>()?);
     let x = Tensor::ones((3, 2), DType::F32, &device)?;
     println!("{:?}", x.to_vec2::<f32>()?);
     Ok(())
diff --git a/kernels/src/binary.cu b/kernels/src/binary.cu
index bfb97470..df145087 100644
--- a/kernels/src/binary.cu
+++ b/kernels/src/binary.cu
@@ -2,28 +2,16 @@
 
 #if __CUDA_ARCH__ >= 530
 BINARY_OP(__half, badd_f16, x + y)
+BINARY_OP(__half, bdiv_f16, x / y)
+BINARY_OP(__half, bmul_f16, x * y)
+BINARY_OP(__half, bsub_f16, x - y)
 #endif
 
 BINARY_OP(float, badd_f32, x + y)
 BINARY_OP(double, badd_fwd_f64, x + y);
-
-#if __CUDA_ARCH__ >= 530
-BINARY_OP(__half, bdiv_f16, x / y)
-#endif
-
 BINARY_OP(float, bdiv_f32, x / y)
 BINARY_OP(double, bdiv_f64, x / y);
-
-#if __CUDA_ARCH__ >= 530
-BINARY_OP(__half, bmul_f16, x * y)
-#endif
-
 BINARY_OP(float, bmul_f32, x * y)
 BINARY_OP(double, bmul_f64, x * y);
-
-#if __CUDA_ARCH__ >= 530
-BINARY_OP(__half, bsub_f16, x - y)
-#endif
-
 BINARY_OP(float, bsub_f32, x - y)
 BINARY_OP(double, bsub_f64, x - y);
diff --git a/kernels/src/lib.rs b/kernels/src/lib.rs
index 3702964c..8e0d9eb9 100644
--- a/kernels/src/lib.rs
+++ b/kernels/src/lib.rs
@@ -1,3 +1,4 @@
 pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
 pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
 pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
+pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
diff --git a/kernels/src/unary.cu b/kernels/src/unary.cu
new file mode 100644
index 00000000..53531933
--- /dev/null
+++ b/kernels/src/unary.cu
@@ -0,0 +1,26 @@
+#include "cuda_utils.cuh"
+
+#define UNARY_OP(TYPENAME, FN_NAME, FUNC) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const TYPENAME *inp, \
+    TYPENAME *out \
+) { \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+        TYPENAME x = inp ? inp[i] : out[i]; \
+        out[i] = FUNC; \
+    } \
+} \
+
+#if __CUDA_ARCH__ >= 530
+UNARY_OP(__half, uneg_f16, -x)
+UNARY_OP(__half, usqr_f16, x*x)
+UNARY_OP(__half, usqrt_f16, sqrtg(x))
+#endif
+
+UNARY_OP(float, uneg_f32, -x)
+UNARY_OP(double, uneg_f64, -x)
+UNARY_OP(float, usqr_f32, x*x)
+UNARY_OP(double, usqr_f64, x*x)
+UNARY_OP(float, usqrt_f32, sqrtg(x))
+UNARY_OP(double, usqrt_f64, sqrtg(x))
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 12215ddc..ce0e803d 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -145,21 +145,53 @@ impl CudaStorage {
         match self {
             Self::F32(arg) => {
                 let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
-                // SAFETY: if this function returns Ok(..), the kernel has been applied
-                // and has set the initially unset memory.
+                // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
                 let params = (elem_count, arg, &out, mul as f32, add as f32);
-                // SAFETY: well, well, well...
+                // SAFETY: ffi.
                 unsafe { func.launch(cfg, params) }?;
                 Ok(Self::F32(out))
             }
             Self::F64(arg) => {
                 let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
-                // SAFETY: if this function returns Ok(..), the kernel has been applied
-                // and has set the initially unset memory.
+                // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
                 let params = (elem_count, arg, &out, mul, add);
-                // SAFETY: well, well, well...
+                // SAFETY: ffi.
                 unsafe { func.launch(cfg, params) }?;
                 Ok(Self::F64(out))
             }
         }
     }
+
+    pub(crate) fn unary_impl<U: crate::storage::UnaryOp>(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+    ) -> Result<Self> {
+        if !shape.is_contiguous(stride) {
+            return Err(CudaError::RequiresContiguous { op: "unary" });
+        }
+
+        let elem_count = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+        let dev = self.device();
+        match self {
+            Self::F32(arg) => {
+                let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
+                let params = (elem_count, arg, &out);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                Ok(Self::F32(out))
+            }
+            Self::F64(arg) => {
+                let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
+                let params = (elem_count, arg, &out);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                Ok(Self::F64(out))
+            }
+        }
+    }
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
index 512f7b8f..d5e0ae63 100644
--- a/src/dummy_cuda_backend.rs
+++ b/src/dummy_cuda_backend.rs
@@ -54,6 +54,14 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn unary_impl<U: crate::storage::UnaryOp>(
+        &self,
+        _: &Shape,
+        _: &[usize],
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn binary_impl<B: crate::storage::BinaryOp>(
         &self,
         _: &Self,
diff --git a/src/storage.rs b/src/storage.rs
index 4c74ffd5..e96f4706 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -8,12 +8,18 @@ pub enum Storage {
 
 pub(crate) trait UnaryOp {
     const NAME: &'static str;
+    // TODO: These kernels are compatible with arbitrary strides. We should also consider the
+    // contiguous case separately as it's easy to optimize things out there.
+    const KERNEL_F32: &'static str;
+    const KERNEL_F64: &'static str;
     fn f32(v1: f32) -> f32;
     fn f64(v1: f64) -> f64;
 }
 
 pub(crate) trait BinaryOp {
     const NAME: &'static str;
+    // TODO: These kernels are compatible with arbitrary strides. We should also consider the
+    // contiguous case separately as it's easy to optimize things out there.
     const KERNEL_F32: &'static str;
     const KERNEL_F64: &'static str;
     fn f32(v1: f32, v2: f32) -> f32;
@@ -84,6 +90,8 @@ impl UnaryOp for Neg {
     fn f64(v1: f64) -> f64 {
         -v1
     }
+    const KERNEL_F32: &'static str = "uneg_f32";
+    const KERNEL_F64: &'static str = "uneg_f64";
 }
 
 impl UnaryOp for Sqr {
@@ -94,6 +102,8 @@ impl UnaryOp for Sqr {
     fn f64(v1: f64) -> f64 {
         v1 * v1
     }
+    const KERNEL_F32: &'static str = "usqr_f32";
+    const KERNEL_F64: &'static str = "usqr_f64";
 }
 
 impl UnaryOp for Sqrt {
@@ -104,6 +114,8 @@ impl UnaryOp for Sqrt {
     fn f64(v1: f64) -> f64 {
         v1.sqrt()
     }
+    const KERNEL_F32: &'static str = "usqrt_f32";
+    const KERNEL_F64: &'static str = "usqrt_f64";
 }
 
 impl Storage {
@@ -168,7 +180,10 @@ impl Storage {
                 let storage = storage.unary_impl::<B>(shape, stride)?;
                 Ok(Self::Cpu(storage))
             }
-            Self::Cuda { .. } => todo!(),
+            Self::Cuda(storage) => {
+                let storage = storage.unary_impl::<B>(shape, stride)?;
+                Ok(Self::Cuda(storage))
+            }
         }
     }
 
@@ -269,7 +284,14 @@ impl Storage {
                 let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
                 Ok(Self::Cpu(storage))
             }
-            _ => todo!(),
+            (Self::Cuda(_), Self::Cuda(_)) => {
+                todo!()
+            }
+            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "matmul",
+            }),
         }
     }
 }
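Side note on the new unary kernels: UNARY_OP generates one extern "C" __global__ kernel per (dtype, op) pair. The grid-stride loop lets a fixed-size grid cover any element count, and a null inp pointer is treated as "apply the op in place on out". Below is a minimal standalone sketch of how one instantiation expands and launches. It substitutes plain sqrtf for candle's sqrtg (which lives in cuda_utils.cuh and is not reproduced here) so it builds on its own with nvcc, and the harness around it (buffer names, sizes) is illustrative only.

#include <cstdio>
#include <cuda_runtime.h>

// Same shape as the patch's UNARY_OP, with sqrtf in place of sqrtg.
#define UNARY_OP(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const TYPENAME *inp, \
    TYPENAME *out \
) { \
    /* Grid-stride loop: a fixed grid covers any numel. */ \
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
        /* A null inp means: apply the op in place on out. */ \
        TYPENAME x = inp ? inp[i] : out[i]; \
        out[i] = FUNC; \
    } \
}

UNARY_OP(float, usqrt_f32, sqrtf(x))

int main() {
    const size_t numel = 5;
    const float host_in[numel] = {4.f, 9.f, 16.f, 25.f, 36.f};
    float host_out[numel];

    float *dev_in = nullptr, *dev_out = nullptr;
    cudaMalloc(&dev_in, numel * sizeof(float));
    cudaMalloc(&dev_out, numel * sizeof(float));
    cudaMemcpy(dev_in, host_in, numel * sizeof(float), cudaMemcpyHostToDevice);

    // Out-of-place: read dev_in, write dev_out.
    usqrt_f32<<<1, 32>>>(numel, dev_in, dev_out);
    // In-place: the null input pointer makes the kernel read dev_out itself.
    usqrt_f32<<<1, 32>>>(numel, nullptr, dev_out);

    cudaMemcpy(host_out, dev_out, numel * sizeof(float), cudaMemcpyDeviceToHost);
    for (size_t i = 0; i < numel; i++) {
        printf("%f ", host_out[i]); // fourth roots: ~1.414 1.732 2.000 ...
    }
    printf("\n");

    cudaFree(dev_in);
    cudaFree(dev_out);
    return 0;
}

The second launch exercises the in-place path that the kernels provide for but the patch does not yet use: unary_impl always passes a real input buffer.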
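On the host side, unary_impl sizes its launch with LaunchConfig::for_num_elems(elem_count as u32). A small sketch of the ceil-divide arithmetic such a helper has to perform; blocks_for is a hypothetical name of mine and the 1024-thread block size is an assumption, not a constant taken from cudarc's source:

#include <cstdio>

// Hypothetical helper: enough blocks so block_dim * grid covers numel.
static unsigned int blocks_for(size_t numel, unsigned int block_dim) {
    return (unsigned int)((numel + block_dim - 1) / block_dim); // ceil division
}

int main() {
    const unsigned int block_dim = 1024;
    const size_t sizes[] = {5, 1024, 1025};
    for (size_t numel : sizes) {
        printf("numel=%zu -> %u block(s) of %u threads\n",
               numel, blocks_for(numel, block_dim), block_dim);
    }
    return 0;
}

Whatever block size the real helper picks, the grid-stride loop keeps the kernels correct: a grid smaller than numel just means each thread handles more than one element, so the config only affects occupancy.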