diff --git a/kernels/src/unary.cu b/kernels/src/unary.cu
index 7683d0d3..5bee3725 100644
--- a/kernels/src/unary.cu
+++ b/kernels/src/unary.cu
@@ -18,11 +18,14 @@ extern "C" __global__ void FN_NAME( \
 } \
 
 #if __CUDA_ARCH__ >= 530
+UNARY_OP(__half, ucopy_f16, x)
 UNARY_OP(__half, uneg_f16, -x)
 UNARY_OP(__half, usqr_f16, x*x)
 UNARY_OP(__half, usqrt_f16, sqrtg(x))
 #endif
 
+UNARY_OP(float, ucopy_f32, x)
+UNARY_OP(double, ucopy_f64, x)
 UNARY_OP(float, uneg_f32, -x)
 UNARY_OP(float, uneg_f64, -x)
 UNARY_OP(float, usqr_f32, x*x)
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 7b6dd655..a8beba18 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -381,4 +381,42 @@ impl CudaStorage {
         let device = dev.clone();
         Ok(Self { slice, device })
     }
+
+    pub(crate) fn copy_strided_src(
+        &self,
+        dst: &mut Self,
+        src_shape: &Shape,
+        src_stride: &[usize],
+        dst_offset: usize,
+    ) -> Result<()> {
+        let dims = src_shape.dims();
+        let el_count = src_shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el_count as u32);
+        let dev = &self.device;
+        let ds = dev.htod_copy([dims, src_stride].concat())?;
+        match (&self.slice, &mut dst.slice) {
+            (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
+                let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
+                let mut dst = dst.slice_mut(dst_offset..);
+                // The destination slice is filled in by running the kernel.
+                let params = (el_count, dims.len(), &ds, src, &mut dst);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+            }
+            (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
+                let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
+                let mut dst = dst.slice_mut(dst_offset..);
+                // The destination slice is filled in by running the kernel.
+                let params = (el_count, dims.len(), &ds, src, &mut dst);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+            }
+            _ => {
+                return Err(CudaError::InternalError(
+                    "dtype mismatch in copy_strided op",
+                ))
+            }
+        }
+        Ok(())
+    }
 }
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
index fbcfe758..a12bafe3 100644
--- a/src/dummy_cuda_backend.rs
+++ b/src/dummy_cuda_backend.rs
@@ -85,4 +85,14 @@ impl CudaStorage {
     ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+
+    pub(crate) fn copy_strided_src(
+        &self,
+        _: &mut Self,
+        _: &Shape,
+        _: &[usize],
+        _: usize,
+    ) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
 }
diff --git a/src/storage.rs b/src/storage.rs
index a1f5d121..bb2aad4c 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -160,8 +160,8 @@ impl Storage {
             (Self::Cpu(src), Self::Cpu(dst)) => {
                 src.copy_strided_src(dst, src_shape, src_stride, dst_offset)
             }
-            (Self::Cuda(_src), Self::Cuda(_dst)) => {
-                todo!()
+            (Self::Cuda(src), Self::Cuda(dst)) => {
+                Ok(src.copy_strided_src(dst, src_shape, src_stride, dst_offset)?)
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
diff --git a/src/tensor.rs b/src/tensor.rs
index e9a41937..1e411dcc 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -596,6 +596,9 @@ impl Tensor {
     }
 
     // TODO: Do we want to allow target shape using -1 on some dimensions?
+    /// Reshape returns a tensor with the target shape, provided that the number of elements of
+    /// the original tensor matches. This uses a new storage and copies the data over; the
+    /// returned tensor is always contiguous.
    pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
         let shape = shape.into();
         if shape.elem_count() != self.elem_count() {
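
For context on the kernel side: the `ds` buffer passed to the `ucopy_*` kernels is `[dims, src_stride].concat()`, i.e. the source shape followed by the source strides. Below is a minimal CPU-side sketch of the strided gather the kernels presumably perform; the function name and the index-decomposition loop are illustrative, not the actual kernel code.

```rust
/// CPU-side sketch of the `ucopy_*` strided copy: for each contiguous
/// destination index `i`, decompose it into multi-dim coordinates and
/// re-linearize with the source strides. `info` is laid out as in the
/// diff: `[dims, src_stride].concat()`.
fn copy_strided_sketch<T: Copy>(src: &[T], dst: &mut [T], info: &[usize]) {
    let num_dims = info.len() / 2;
    let (dims, strides) = info.split_at(num_dims);
    for i in 0..dst.len() {
        let mut remainder = i;
        let mut src_i = 0;
        // Walk dimensions from innermost to outermost.
        for d in (0..num_dims).rev() {
            src_i += (remainder % dims[d]) * strides[d];
            remainder /= dims[d];
        }
        dst[i] = src[src_i];
    }
}

fn main() {
    // A 2x3 row-major matrix viewed transposed: shape [3, 2], strides [1, 3].
    // Copying it out produces the contiguous layout of the transpose.
    let src = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let mut dst = [0.0f32; 6];
    copy_strided_sketch(&src, &mut dst, &[3, 2, 1, 3]);
    assert_eq!(dst, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}
```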
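And a hedged sketch of the user-facing path this unblocks: `reshape` on a CUDA tensor now goes through `copy_strided_src` instead of the old `todo!()`. The crate name, `Device::new_cuda`, and `Tensor::new` are assumptions based on the rest of the crate, not part of this diff.

```rust
use candle::{Device, Tensor}; // crate and constructor names assumed

fn main() -> candle::Result<()> {
    let device = Device::new_cuda(0)?;
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &device)?;
    // reshape copies into fresh contiguous storage via copy_strided_src.
    let r = t.reshape((3, 2))?;
    assert_eq!(r.shape().dims(), &[3, 2]);
    Ok(())
}
```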