diff --git a/kernels/src/affine.cu b/kernels/src/affine.cu
index da2cc46d..af25aba1 100644
--- a/kernels/src/affine.cu
+++ b/kernels/src/affine.cu
@@ -1,27 +1,42 @@
+#include "cuda_utils.cuh"
+
 extern "C" __global__ void affine_f32(
     const size_t numel,
+    const size_t num_dims,
+    const size_t *info,
     const float *x,
     float *y,
     const float mul,
     const float add
 ) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
     unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
     if (i >= numel) {
         return;
     }
-    y[i] = x[i] * mul + add;
+    // This is likely to be very very slow, we should either optimize the contiguous case
+    // as a separate kernel, proceed by block, improve the stride computations (and probably
+    // do all of these).
+    unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
+    y[i] = x[strided_i] * mul + add;
 }
 
 extern "C" __global__ void affine_f64(
     const size_t numel,
+    const size_t num_dims,
+    const size_t *info,
     const double *x,
     double *y,
     const double mul,
     const double add
 ) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
     unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
     if (i >= numel) {
         return;
     }
-    y[i] = x[i] * mul + add;
+    unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
+    y[i] = x[strided_i] * mul + add;
 }
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 378db1ce..66556f40 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -135,19 +135,17 @@ impl CudaStorage {
         mul: f64,
         add: f64,
     ) -> Result<Self> {
-        if !shape.is_contiguous(stride) {
-            return Err(CudaError::RequiresContiguous { op: "affine" });
-        }
-
-        let elem_count = shape.elem_count();
-        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+        let dims = shape.dims();
+        let el_count = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let dev = self.device();
+        let ds = dev.0.htod_copy([dims, stride].concat())?;
         match self {
             Self::F32(arg) => {
                 let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
                 // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
-                let params = (elem_count, arg, &out, mul as f32, add as f32);
+                let out = unsafe { dev.0.alloc::<f32>(el_count) }?;
+                let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
                 // SAFETY: ffi.
                 unsafe { func.launch(cfg, params) }?;
                 Ok(Self::F32(out))
@@ -155,8 +153,8 @@ impl CudaStorage {
             Self::F64(arg) => {
                 let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
                 // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
-                let params = (elem_count, arg, &out, mul, add);
+                let out = unsafe { dev.0.alloc::<f64>(el_count) }?;
+                let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
                 // SAFETY: ffi.
                 unsafe { func.launch(cfg, params) }?;
                 Ok(Self::F64(out))
@@ -209,13 +207,12 @@ impl CudaStorage {
         let dims = shape.dims();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
         let dev = self.device();
-        let dims_and_strides = [dims, lhs_stride, rhs_stride].concat();
+        let dims_and_strides = dev.0.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
         match (self, rhs) {
             (Self::F32(lhs), Self::F32(rhs)) => {
                 let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
-                let dims_and_strides = dev.0.htod_copy(dims_and_strides)?;
                 let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
                 // SAFETY: ffi
                 unsafe { func.launch(cfg, params) }?;
@@ -225,7 +222,6 @@ impl CudaStorage {
                 // SAFETY: Set later by running the kernel.
                 let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
                 let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
-                let dims_and_strides = dev.0.htod_copy(dims_and_strides)?;
                 let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
                 // SAFETY: ffi
                 unsafe { func.launch(cfg, params) }?;
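
Both affine kernels look up the input element through `get_strided_index`, which comes from the newly included `cuda_utils.cuh` and is not shown in this diff. A minimal sketch of such a helper, assuming the usual row-major decomposition of a linear element index into per-dimension coordinates (the actual definition in `cuda_utils.cuh` may differ), could look like this:

```cuda
#include <stddef.h>

// Hypothetical sketch: map a linear element index to the physical offset of a
// (possibly non-contiguous) tensor described by `dims` and `strides`.
__device__ unsigned int get_strided_index(
    unsigned int idx,
    const size_t num_dims,
    const size_t *dims,
    const size_t *strides
) {
    unsigned int strided_i = 0;
    // Peel off coordinates starting from the innermost (last) dimension and
    // accumulate the strided offset for each of them.
    for (unsigned int d = 0; d < num_dims; d++) {
        const unsigned int dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}
```

On the host side, the shape and strides are concatenated into a single device buffer (`[dims, stride].concat()` copied with `htod_copy`), which is why the kernels recover `dims` as `info` and `strides` as `info + num_dims`.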