Cuda backend optimization (#1886)

* Attempt at making the kernel faster.

* Also adapt the cast kernels.

* Also apply to binary ops.
This commit is contained in:
Laurent Mazare
2024-03-20 18:32:55 +01:00
committed by GitHub
parent 469635a3eb
commit b219903d0f
5 changed files with 54 additions and 19 deletions

View File

@ -11,6 +11,31 @@ use cudarc::driver::{
use half::{bf16, f16};
use std::sync::{Arc, Mutex};
/// Kernel argument that is either a device slice or an explicit "no data" marker.
///
/// `params_from_layout` builds `Null` when the layout is contiguous and `Ptr` (holding the
/// concatenated dims and strides) otherwise, so the host-to-device copy is only performed
/// when the layout is not contiguous.
enum SlicePtrOrNull<T> {
/// A slice that is forwarded to the kernel through its own `as_kernel_param`.
Ptr(CudaSlice<T>),
/// No slice is passed; the kernel parameter is produced from the literal `0usize` instead.
Null,
}
// Lets a `&SlicePtrOrNull<T>` be used directly as a kernel launch parameter.
//
// The kernels touched by the same change test the corresponding argument against `nullptr`
// (`info == nullptr` / `dims_and_strides == nullptr`) and take their contiguous code path in
// that case, without looking at dims or strides.
unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
fn as_kernel_param(&self) -> *mut std::ffi::c_void {
match self {
// Delegate to the wrapped slice so the kernel receives whatever it would have
// received had the slice been passed on its own.
SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(),
// NOTE(review): this takes the kernel parameter of the literal `0usize`; presumably the
// kernel then reads a zero pointer-sized value, i.e. `nullptr`. Keep the literal
// in place: confirm that the returned pointer stays valid until the launch before
// replacing it with a local variable.
SlicePtrOrNull::Null => 0usize.as_kernel_param(),
}
}
}
impl SlicePtrOrNull<usize> {
    /// Builds the dims/strides kernel argument for the layout `l`.
    ///
    /// Returns `Null` without touching the device when `l` is contiguous. Otherwise the
    /// dims of `l` followed by its strides are concatenated into a single buffer, copied
    /// to `dev`, and returned as `Ptr`. A failure of the host-to-device copy is propagated
    /// to the caller.
    fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
        // Contiguous layouts need no metadata on the device side, so skip the copy.
        if l.is_contiguous() {
            return Ok(SlicePtrOrNull::Null);
        }
        // Order matters: dims first, strides second.
        let dims_then_strides = [l.dims(), l.stride()].concat();
        let on_device = dev.htod_copy(dims_then_strides).w()?;
        Ok(SlicePtrOrNull::Ptr(on_device))
    }
}
/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
@ -564,7 +589,7 @@ impl Map1 for Affine {
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
@ -596,7 +621,7 @@ impl Map1 for Elu {
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
@ -719,7 +744,7 @@ impl Map1 for Powf {
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
@ -852,7 +877,7 @@ impl<U: UnaryOpT> Map1 for U {
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
@ -1402,9 +1427,14 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
let dims = shape.dims();
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dims_and_strides = dev
.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?;
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
SlicePtrOrNull::Null
} else {
SlicePtrOrNull::Ptr(
dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?,
)
};
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
@ -1431,9 +1461,14 @@ impl Map2Any for Cmp {
let dims = shape.dims();
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dims_and_strides = dev
.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?;
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
SlicePtrOrNull::Null
} else {
SlicePtrOrNull::Ptr(
dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?,
)
};
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let name = match self.0 {
@ -1640,7 +1675,7 @@ impl BackendStorage for CudaStorage {
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let start_o = layout.start_offset();
// This returns an i64 rather than a &i64, this is useful to get around some temporary
// lifetime issue and is safe as long as self.slice does not go out of scope before inp
@ -2215,7 +2250,7 @@ impl BackendStorage for CudaStorage {
}
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
let ds = SlicePtrOrNull::params_from_layout(dev, src_l)?;
match (&self.slice, &mut dst.slice) {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);

View File

@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (is_contiguous(num_dims, dims, strides)) { \
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
TYPENAME x = inp ? inp[i] : out[i]; \
out[i] = x * mul + add; \

View File

@ -12,8 +12,8 @@ extern "C" __global__ void FN_NAME( \
const size_t *dims = dims_and_strides; \
const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \
bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
bool lhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, lhs_strides); \
bool rhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, rhs_strides); \
if (lhs_cont && rhs_cont) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
TYPENAME x = lhs[i]; \

View File

@ -11,7 +11,7 @@ __device__ void cast_(
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (is_contiguous(num_dims, dims, strides)) {
if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = inp[i];
}
@ -34,7 +34,7 @@ __device__ void cast_through(
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (is_contiguous(num_dims, dims, strides)) {
if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = static_cast<T>(static_cast<I>(inp[i]));
}

View File

@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (is_contiguous(num_dims, dims, strides)) { \
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
TYPENAME x = inp ? inp[i] : out[i]; \
out[i] = FUNC; \
@ -71,7 +71,7 @@ extern "C" __global__ void FN_NAME( \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (is_contiguous(num_dims, dims, strides)) { \
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
TYPENAME x = inp ? inp[i] : out[i]; \
out[i] = FUNC; \