diff --git a/kernels/src/binary.cu b/kernels/src/binary.cu new file mode 100644 index 00000000..bfb97470 --- /dev/null +++ b/kernels/src/binary.cu @@ -0,0 +1,29 @@ +#include "binary_op_macros.cuh" + +#if __CUDA_ARCH__ >= 530 +BINARY_OP(__half, badd_f16, x + y) +#endif + +BINARY_OP(float, badd_f32, x + y) +BINARY_OP(double, badd_fwd_f64, x + y); + +#if __CUDA_ARCH__ >= 530 +BINARY_OP(__half, bdiv_f16, x / y) +#endif + +BINARY_OP(float, bdiv_f32, x / y) +BINARY_OP(double, bdiv_f64, x / y); + +#if __CUDA_ARCH__ >= 530 +BINARY_OP(__half, bmul_f16, x * y) +#endif + +BINARY_OP(float, bmul_f32, x * y) +BINARY_OP(double, bmul_f64, x * y); + +#if __CUDA_ARCH__ >= 530 +BINARY_OP(__half, bsub_f16, x - y) +#endif + +BINARY_OP(float, bsub_f32, x - y) +BINARY_OP(double, bsub_f64, x - y); diff --git a/kernels/src/binary_add.cu b/kernels/src/binary_add.cu deleted file mode 100644 index 979a429e..00000000 --- a/kernels/src/binary_add.cu +++ /dev/null @@ -1,8 +0,0 @@ -#include "binary_op_macros.cuh" - -#if __CUDA_ARCH__ >= 530 -BINARY_OP(__half, badd_f16, x + y) -#endif - -BINARY_OP(float, badd_f32, x + y) -BINARY_OP(double, badd_fwd_f64, x + y); diff --git a/kernels/src/binary_div.cu b/kernels/src/binary_div.cu deleted file mode 100644 index 410eed62..00000000 --- a/kernels/src/binary_div.cu +++ /dev/null @@ -1,8 +0,0 @@ -#include "binary_op_macros.cuh" - -#if __CUDA_ARCH__ >= 530 -BINARY_OP(__half, bdiv_f16, x / y) -#endif - -BINARY_OP(float, bdiv_f32, x / y) -BINARY_OP(double, bdiv_f64, x / y); diff --git a/kernels/src/binary_mul.cu b/kernels/src/binary_mul.cu deleted file mode 100644 index 8adab74b..00000000 --- a/kernels/src/binary_mul.cu +++ /dev/null @@ -1,8 +0,0 @@ -#include "binary_op_macros.cuh" - -#if __CUDA_ARCH__ >= 530 -BINARY_OP(__half, bmul_f16, x * y) -#endif - -BINARY_OP(float, bmul_f32, x * y) -BINARY_OP(double, bmul_f64, x * y); diff --git a/kernels/src/binary_sub.cu b/kernels/src/binary_sub.cu deleted file mode 100644 index 891932db..00000000 --- a/kernels/src/binary_sub.cu +++ /dev/null @@ -1,8 +0,0 @@ -#include "binary_op_macros.cuh" - -#if __CUDA_ARCH__ >= 530 -BINARY_OP(__half, bsub_f16, x - y) -#endif - -BINARY_OP(float, bsub_f32, x - y) -BINARY_OP(double, bsub_f64, x - y); diff --git a/kernels/src/lib.rs b/kernels/src/lib.rs index 5ebea218..3702964c 100644 --- a/kernels/src/lib.rs +++ b/kernels/src/lib.rs @@ -1,6 +1,3 @@ pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx")); -pub const BINARY_ADD: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_add.ptx")); -pub const BINARY_DIV: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_div.ptx")); -pub const BINARY_MUL: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_mul.ptx")); -pub const BINARY_SUB: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_sub.ptx")); +pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx")); pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx")); diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index de70c835..12215ddc 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -166,7 +166,7 @@ impl CudaStorage { } } - pub(crate) fn add_impl( + pub(crate) fn binary_impl( &self, rhs: &Self, shape: &Shape, @@ -180,8 +180,8 @@ impl CudaStorage { let dims_and_strides = [dims, lhs_stride, rhs_stride].concat(); match (self, rhs) { (Self::F32(lhs), Self::F32(rhs)) => { - let func = dev.get_or_load_func("badd_f32", kernels::BINARY_ADD)?; - // SAFETY: Set later by running the add kernel. + let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?; + // SAFETY: Set later by running the kernel. let out = unsafe { dev.0.alloc::(elem_count) }?; let dims_and_strides = dev.0.htod_copy(dims_and_strides)?; let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); @@ -190,8 +190,8 @@ impl CudaStorage { Ok(Self::F32(out)) } (Self::F64(lhs), Self::F64(rhs)) => { - // SAFETY: Set later by running the add kernel. - let func = dev.get_or_load_func("badd_f64", kernels::BINARY_ADD)?; + // SAFETY: Set later by running the kernel. + let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?; let out = unsafe { dev.0.alloc::(elem_count) }?; let dims_and_strides = dev.0.htod_copy(dims_and_strides)?; let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); @@ -200,7 +200,7 @@ impl CudaStorage { Ok(Self::F64(out)) } // The dtypes should have been checked at this point so this is an internal error. - _ => Err(CudaError::InternalError("dtype mismatch in add")), + _ => Err(CudaError::InternalError("dtype mismatch in binary op")), } } diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs index 16e78fbe..512f7b8f 100644 --- a/src/dummy_cuda_backend.rs +++ b/src/dummy_cuda_backend.rs @@ -54,7 +54,13 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn add_impl(&self, _: &Self, _: &Shape, _: &[usize], _: &[usize]) -> Result { + pub(crate) fn binary_impl( + &self, + _: &Self, + _: &Shape, + _: &[usize], + _: &[usize], + ) -> Result { Err(Error::NotCompiledWithCudaSupport) } } diff --git a/src/storage.rs b/src/storage.rs index 22f9a26c..4c74ffd5 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -14,15 +14,10 @@ pub(crate) trait UnaryOp { pub(crate) trait BinaryOp { const NAME: &'static str; + const KERNEL_F32: &'static str; + const KERNEL_F64: &'static str; fn f32(v1: f32, v2: f32) -> f32; fn f64(v1: f64, v2: f64) -> f64; - fn cuda_impl( - lhs: &CudaStorage, - rhs: &CudaStorage, - shape: &Shape, - lhs_stride: &[usize], - rhs_stride: &[usize], - ) -> Result; } struct Add; @@ -35,78 +30,50 @@ struct Sqrt; impl BinaryOp for Add { const NAME: &'static str = "add"; + const KERNEL_F32: &'static str = "badd_f32"; + const KERNEL_F64: &'static str = "badd_f64"; fn f32(v1: f32, v2: f32) -> f32 { v1 + v2 } fn f64(v1: f64, v2: f64) -> f64 { v1 + v2 } - fn cuda_impl( - lhs: &CudaStorage, - rhs: &CudaStorage, - shape: &Shape, - lhs_stride: &[usize], - rhs_stride: &[usize], - ) -> Result { - Ok(lhs.add_impl(rhs, shape, lhs_stride, rhs_stride)?) - } } impl BinaryOp for Sub { const NAME: &'static str = "sub"; + const KERNEL_F32: &'static str = "bsub_f32"; + const KERNEL_F64: &'static str = "bsub_f64"; fn f32(v1: f32, v2: f32) -> f32 { v1 - v2 } fn f64(v1: f64, v2: f64) -> f64 { v1 - v2 } - fn cuda_impl( - _: &CudaStorage, - _: &CudaStorage, - _: &Shape, - _: &[usize], - _: &[usize], - ) -> Result { - todo!() - } } impl BinaryOp for Mul { const NAME: &'static str = "mul"; + const KERNEL_F32: &'static str = "bmul_f32"; + const KERNEL_F64: &'static str = "bmul_f64"; fn f32(v1: f32, v2: f32) -> f32 { v1 * v2 } fn f64(v1: f64, v2: f64) -> f64 { v1 * v2 } - fn cuda_impl( - _: &CudaStorage, - _: &CudaStorage, - _: &Shape, - _: &[usize], - _: &[usize], - ) -> Result { - todo!() - } } impl BinaryOp for Div { const NAME: &'static str = "div"; + const KERNEL_F32: &'static str = "bdiv_f32"; + const KERNEL_F64: &'static str = "bdiv_f64"; fn f32(v1: f32, v2: f32) -> f32 { v1 / v2 } fn f64(v1: f64, v2: f64) -> f64 { v1 / v2 } - fn cuda_impl( - _: &CudaStorage, - _: &CudaStorage, - _: &Shape, - _: &[usize], - _: &[usize], - ) -> Result { - todo!() - } } impl UnaryOp for Neg { @@ -221,7 +188,7 @@ impl Storage { Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = B::cuda_impl(lhs, rhs, shape, lhs_stride, rhs_stride)?; + let storage = lhs.binary_impl::(rhs, shape, lhs_stride, rhs_stride)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => {