From b219903d0f9ee52f70397c7e9aa4df323b89a700 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 20 Mar 2024 18:32:55 +0100
Subject: [PATCH] Cuda backend optimization (#1886)

* Attempt at making the kernel faster.

* Also adapt the cast kernels.

* Also apply to binary ops.
---
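For contiguous layouts, the host now passes a null pointer in place of the
dims/strides metadata, saving one allocation and one htod_copy per kernel
launch on the hot path; the kernels short-circuit on the null check and never
touch the metadata. On the host side, the DeviceRepr impl for &SlicePtrOrNull
maps the Null variant to a null kernel parameter, so call-sites only swap the
htod_copy for params_from_layout.

For reference, a minimal nvcc-compilable sketch of the kernel-side pattern.
It is illustrative only: copy_f32 is a made-up kernel name, and the two
helpers are simplified stand-ins for the ones in
candle-kernels/src/cuda_utils.cuh.

#include <cstddef>
#include <cuda_runtime.h>

// Simplified stand-in for candle's is_contiguous: strides must equal the
// products of the trailing dims (row-major contiguity).
__device__ bool is_contiguous(const size_t num_dims, const size_t *dims,
                              const size_t *strides) {
    size_t acc = 1;
    for (unsigned int d = 0; d < num_dims; d++) {
        unsigned int dim_idx = num_dims - 1 - d;
        if (acc != strides[dim_idx]) return false;
        acc *= dims[dim_idx];
    }
    return true;
}

// Simplified stand-in for candle's get_strided_index: decompose the linear
// index into per-dimension coordinates, then re-linearize with the strides.
__device__ unsigned int get_strided_index(unsigned int idx, const size_t num_dims,
                                          const size_t *dims, const size_t *strides) {
    unsigned int strided_i = 0;
    for (unsigned int d = 0; d < num_dims; d++) {
        unsigned int dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

extern "C" __global__ void copy_f32(const size_t numel, const size_t num_dims,
                                    const size_t *info, const float *inp, float *out) {
    const size_t *dims = info;
    const size_t *strides = info + num_dims;
    // A null info pointer is the host-side signal for "contiguous, no metadata
    // was copied"; the || short-circuits before anything is dereferenced.
    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
            out[i] = inp[i];
        }
    } else {
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
            out[i] = inp[get_strided_index(i, num_dims, dims, strides)];
        }
    }
}

int main() {
    const size_t numel = 1024;
    float *inp, *out;
    cudaMalloc(&inp, numel * sizeof(float));
    cudaMalloc(&out, numel * sizeof(float));
    // Contiguous launch: no metadata buffer at all, just a null pointer.
    copy_f32<<<4, 256>>>(numel, 1, nullptr, inp, out);
    cudaDeviceSynchronize();
    cudaFree(inp);
    cudaFree(out);
    return 0;
}
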
 candle-core/src/cuda_backend.rs         | 59 ++++++++++++++++++++-----
 candle-kernels/src/affine.cu            |  2 +-
 candle-kernels/src/binary_op_macros.cuh |  4 +-
 candle-kernels/src/cast.cu              |  4 +-
 candle-kernels/src/unary.cu             |  4 +-
 5 files changed, 54 insertions(+), 19 deletions(-)

diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 52d1b558..8954fc33 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -11,6 +11,31 @@ use cudarc::driver::{
 use half::{bf16, f16};
 use std::sync::{Arc, Mutex};
 
+enum SlicePtrOrNull<T> {
+    Ptr(CudaSlice<T>),
+    Null,
+}
+
+unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
+    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
+        match self {
+            SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(),
+            SlicePtrOrNull::Null => 0usize.as_kernel_param(),
+        }
+    }
+}
+
+impl SlicePtrOrNull<usize> {
+    fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
+        let ds = if l.is_contiguous() {
+            SlicePtrOrNull::Null
+        } else {
+            SlicePtrOrNull::Ptr(dev.htod_copy([l.dims(), l.stride()].concat()).w()?)
+        };
+        Ok(ds)
+    }
+}
+
 /// cudarc related errors
 #[derive(thiserror::Error, Debug)]
 pub enum CudaError {
@@ -564,7 +589,7 @@ impl Map1 for Affine {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
         // SAFETY: Set later by running the kernel.
@@ -596,7 +621,7 @@ impl Map1 for Elu {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
@@ -719,7 +744,7 @@ impl Map1 for Powf {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
@@ -852,7 +877,7 @@ impl<U: UnaryOpT> Map1 for U {
         let dims = shape.dims();
         let el_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
@@ -1402,9 +1427,14 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
         let dims = shape.dims();
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let dims_and_strides = dev
-            .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
-            .w()?;
+        let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
+            SlicePtrOrNull::Null
+        } else {
+            SlicePtrOrNull::Ptr(
+                dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
+                    .w()?,
+            )
+        };
         let lhs = &lhs.slice(lhs_l.start_offset()..);
         let rhs = &rhs.slice(rhs_l.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
@@ -1431,9 +1461,14 @@ impl Map2Any for Cmp {
         let dims = shape.dims();
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let dims_and_strides = dev
-            .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
-            .w()?;
+        let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
+            SlicePtrOrNull::Null
+        } else {
+            SlicePtrOrNull::Ptr(
+                dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
+                    .w()?,
+            )
+        };
         let lhs = &lhs.slice(lhs_l.start_offset()..);
         let rhs = &rhs.slice(rhs_l.start_offset()..);
         let name = match self.0 {
@@ -1640,7 +1675,7 @@ impl BackendStorage for CudaStorage {
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let dev = self.device();
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let start_o = layout.start_offset();
         // This returns an i64 rather than a &i64, this is useful to get around some temporary
         // lifetime issue and is safe as long as self.slice does not go out of scope before inp
@@ -2215,7 +2250,7 @@ impl BackendStorage for CudaStorage {
         }
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let dev = &self.device;
-        let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, src_l)?;
         match (&self.slice, &mut dst.slice) {
             (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu
index 152b9463..540d0819 100644
--- a/candle-kernels/src/affine.cu
+++ b/candle-kernels/src/affine.cu
@@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
 ) { \
     const size_t *dims = info; \
     const size_t *strides = info + num_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
             TYPENAME x = inp ? inp[i] : out[i]; \
             out[i] = x * mul + add; \
diff --git a/candle-kernels/src/binary_op_macros.cuh b/candle-kernels/src/binary_op_macros.cuh
index 05d0c3df..9cb00874 100644
--- a/candle-kernels/src/binary_op_macros.cuh
+++ b/candle-kernels/src/binary_op_macros.cuh
@@ -12,8 +12,8 @@ extern "C" __global__ void FN_NAME( \
     const size_t *dims = dims_and_strides; \
     const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
     const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
-    bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \
-    bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
+    bool lhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, lhs_strides); \
+    bool rhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, rhs_strides); \
     if (lhs_cont && rhs_cont) { \
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
             TYPENAME x = lhs[i]; \
diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu
index 024642c6..2fe85e1c 100644
--- a/candle-kernels/src/cast.cu
+++ b/candle-kernels/src/cast.cu
@@ -11,7 +11,7 @@ __device__ void cast_(
 ) {
     const size_t *dims = info;
     const size_t *strides = info + num_dims;
-    if (is_contiguous(num_dims, dims, strides)) {
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
             out[i] = inp[i];
         }
@@ -34,7 +34,7 @@ __device__ void cast_through(
 ) {
     const size_t *dims = info;
     const size_t *strides = info + num_dims;
-    if (is_contiguous(num_dims, dims, strides)) {
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
             out[i] = static_cast<T>(static_cast<I>(inp[i]));
         }
diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu
index 74ba1fac..13489897 100644
--- a/candle-kernels/src/unary.cu
+++ b/candle-kernels/src/unary.cu
@@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
 ) { \
     const size_t *dims = info; \
     const size_t *strides = info + num_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
             TYPENAME x = inp ? inp[i] : out[i]; \
             out[i] = FUNC; \
@@ -71,7 +71,7 @@ extern "C" __global__ void FN_NAME( \
 ) { \
     const size_t *dims = info; \
     const size_t *strides = info + num_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
             TYPENAME x = inp ? inp[i] : out[i]; \
             out[i] = FUNC; \