diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index a8beba18..edbb7700 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -396,20 +396,29 @@ impl CudaStorage { let ds = dev.htod_copy([dims, src_stride].concat())?; match (&self.slice, &mut dst.slice) { (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => { - let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?; let mut dst = dst.slice_mut(dst_offset..); - // SAFETY: Set later by running the kernel. - let params = (el_count, dims.len(), &ds, src, &mut dst); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }? + if src_shape.is_contiguous(src_stride) { + dev.dtod_copy(src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }? + } } (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => { - let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?; let mut dst = dst.slice_mut(dst_offset..); - // SAFETY: Set later by running the kernel. - let params = (el_count, dims.len(), &ds, src, &mut dst); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; + if src_shape.is_contiguous(src_stride) { + dev.dtod_copy(src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?; + let mut dst = dst.slice_mut(dst_offset..); + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + } } _ => { return Err(CudaError::InternalError(