diff --git a/kernels/src/binary.cu b/kernels/src/binary.cu index df145087..d8758a5e 100644 --- a/kernels/src/binary.cu +++ b/kernels/src/binary.cu @@ -1,4 +1,5 @@ #include "binary_op_macros.cuh" +#include <stdint.h> #if __CUDA_ARCH__ >= 530 BINARY_OP(__half, badd_f16, x + y) @@ -8,10 +9,14 @@ BINARY_OP(__half, bsub_f16, x - y) #endif BINARY_OP(float, badd_f32, x + y) -BINARY_OP(double, badd_fwd_f64, x + y); +BINARY_OP(double, badd_f64, x + y); +BINARY_OP(uint32_t, badd_u32, x + y); BINARY_OP(float, bdiv_f32, x / y) BINARY_OP(double, bdiv_f64, x / y); +BINARY_OP(uint32_t, bdiv_u32, x / y); BINARY_OP(float, bmul_f32, x * y) BINARY_OP(double, bmul_f64, x * y); +BINARY_OP(uint32_t, bmul_u32, x * y); BINARY_OP(float, bsub_f32, x - y) BINARY_OP(double, bsub_f64, x - y); +BINARY_OP(uint32_t, bsub_u32, x - y); diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index c11504d7..c3afbe44 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -358,6 +358,15 @@ impl CudaStorage { unsafe { func.launch(cfg, params) }?; CudaStorageSlice::F64(out) } + (CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => { + // SAFETY: Set later by running the kernel. + let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?; + let out = unsafe { dev.alloc::<u32>(elem_count) }?; + let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::U32(out) + } // The dtypes should have been checked at this point so this is an internal error. _ => return Err(CudaError::InternalError("dtype mismatch in binary op")), }; diff --git a/src/op.rs b/src/op.rs index 40be9f4c..0f0f5ee4 100644 --- a/src/op.rs +++ b/src/op.rs @@ -47,6 +47,7 @@ pub(crate) trait BinaryOp { // contiguous case separately as it's easy to optimize things out there. 
const KERNEL_F32: &'static str; const KERNEL_F64: &'static str; + const KERNEL_U32: &'static str; fn f32(v1: f32, v2: f32) -> f32; fn f64(v1: f64, v2: f64) -> f64; fn u32(v1: u32, v2: u32) -> u32; @@ -65,6 +66,7 @@ impl BinaryOp for Add { const NAME: &'static str = "add"; const KERNEL_F32: &'static str = "badd_f32"; const KERNEL_F64: &'static str = "badd_f64"; + const KERNEL_U32: &'static str = "badd_u32"; fn f32(v1: f32, v2: f32) -> f32 { v1 + v2 } @@ -80,6 +82,7 @@ impl BinaryOp for Sub { const NAME: &'static str = "sub"; const KERNEL_F32: &'static str = "bsub_f32"; const KERNEL_F64: &'static str = "bsub_f64"; + const KERNEL_U32: &'static str = "bsub_u32"; fn f32(v1: f32, v2: f32) -> f32 { v1 - v2 } @@ -95,6 +98,7 @@ impl BinaryOp for Mul { const NAME: &'static str = "mul"; const KERNEL_F32: &'static str = "bmul_f32"; const KERNEL_F64: &'static str = "bmul_f64"; + const KERNEL_U32: &'static str = "bmul_u32"; fn f32(v1: f32, v2: f32) -> f32 { v1 * v2 } @@ -110,6 +114,7 @@ impl BinaryOp for Div { const NAME: &'static str = "div"; const KERNEL_F32: &'static str = "bdiv_f32"; const KERNEL_F64: &'static str = "bdiv_f64"; + const KERNEL_U32: &'static str = "bdiv_u32"; fn f32(v1: f32, v2: f32) -> f32 { v1 / v2 }