Mirror of https://github.com/huggingface/candle.git
Cuda backend optimization (#1886)
* Attempt at making the kernel faster.
* Also adapt the cast kernels.
* Also apply to binary ops.
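In short: when every layout involved in an affine, unary, binary, cast, or copy launch is contiguous, the host side no longer allocates and copies a dims/strides buffer to the device. It passes a null pointer instead, and the kernels treat a null `info` / `dims_and_strides` pointer as the contiguous fast path. The hunks below add the `SlicePtrOrNull` helper, switch each launch site over to it, and add the matching null checks in the kernels.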
@@ -11,6 +11,31 @@ use cudarc::driver::{
 use half::{bf16, f16};
 use std::sync::{Arc, Mutex};
 
+enum SlicePtrOrNull<T> {
+    Ptr(CudaSlice<T>),
+    Null,
+}
+
+unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
+    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
+        match self {
+            SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(),
+            SlicePtrOrNull::Null => 0usize.as_kernel_param(),
+        }
+    }
+}
+
+impl SlicePtrOrNull<usize> {
+    fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
+        let ds = if l.is_contiguous() {
+            SlicePtrOrNull::Null
+        } else {
+            SlicePtrOrNull::Ptr(dev.htod_copy([l.dims(), l.stride()].concat()).w()?)
+        };
+        Ok(ds)
+    }
+}
+
 /// cudarc related errors
 #[derive(thiserror::Error, Debug)]
 pub enum CudaError {
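For context, cudarc passes kernel arguments through the `DeviceRepr` trait, so the impl above lets a launch site hand `&ds` to `launch` unchanged: the `Ptr` variant contributes a device pointer to the `[dims, strides]` buffer, while `Null` contributes a literal 0 that the kernel sees as a null `info` pointer. A rough sketch of a launch site under that assumption; apart from `SlicePtrOrNull::params_from_layout`, the names (`el`, `dims`, `src`, `out`, `func`, `cfg`) are placeholders from the usual Map1 plumbing, not code from this diff:

    let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
    // `&ds` is a valid kernel parameter thanks to the DeviceRepr impl above:
    // either a device pointer to the metadata buffer or a literal null.
    let params = (el, dims.len(), &ds, src, &out);
    // SAFETY: the kernel only dereferences `info` when it is non-null
    // (see the .cu hunks further down).
    unsafe { func.launch(cfg, params) }.w()?;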
@@ -564,7 +589,7 @@ impl Map1 for Affine {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
         // SAFETY: Set later by running the kernel.
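To make the fast path concrete, here is a hypothetical example written against candle's public API (it assumes a CUDA build and device 0; not part of this commit). A freshly allocated tensor is contiguous, so its affine launch now skips the metadata upload, while a transposed view still takes the strided path:

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        let dev = Device::new_cuda(0)?;
        let a = Tensor::randn(0f32, 1f32, (2, 3), &dev)?; // contiguous, strides [3, 1]
        let b = a.t()?;                                    // transposed view, strides [1, 3]
        // `a` is contiguous: params_from_layout returns Null and no htod_copy happens.
        let _fast = a.affine(2.0, 1.0)?;
        // `b` is not contiguous: the dims/strides buffer is still copied to the device.
        let _slow = b.affine(2.0, 1.0)?;
        Ok(())
    }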
@@ -596,7 +621,7 @@ impl Map1 for Elu {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
@@ -719,7 +744,7 @@ impl Map1 for Powf {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
@@ -852,7 +877,7 @@ impl<U: UnaryOpT> Map1 for U {
         let dims = shape.dims();
         let el_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
@@ -1402,9 +1427,14 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
         let dims = shape.dims();
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let dims_and_strides = dev
-            .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
-            .w()?;
+        let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
+            SlicePtrOrNull::Null
+        } else {
+            SlicePtrOrNull::Ptr(
+                dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
+                    .w()?,
+            )
+        };
         let lhs = &lhs.slice(lhs_l.start_offset()..);
         let rhs = &rhs.slice(rhs_l.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
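The binary version takes the fast path only when both operands are contiguous. Continuing the hypothetical example above (same setup and assumptions):

    let a = Tensor::randn(0f32, 1f32, (4, 4), &dev)?;
    let b = Tensor::randn(0f32, 1f32, (4, 4), &dev)?;
    let _fast = (&a + &b)?;      // both contiguous: dims_and_strides stays Null
    let _slow = (&a + &b.t()?)?; // rhs is strided: the metadata buffer is uploaded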
@@ -1431,9 +1461,14 @@ impl Map2Any for Cmp {
         let dims = shape.dims();
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let dims_and_strides = dev
-            .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
-            .w()?;
+        let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
+            SlicePtrOrNull::Null
+        } else {
+            SlicePtrOrNull::Ptr(
+                dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
+                    .w()?,
+            )
+        };
         let lhs = &lhs.slice(lhs_l.start_offset()..);
         let rhs = &rhs.slice(rhs_l.start_offset()..);
         let name = match self.0 {
@@ -1640,7 +1675,7 @@ impl BackendStorage for CudaStorage {
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let dev = self.device();
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let start_o = layout.start_offset();
         // This returns an i64 rather than a &i64, this is useful to get around some temporary
         // lifetime issue and is safe as long as self.slice does not go out of scope before inp
@@ -2215,7 +2250,7 @@ impl BackendStorage for CudaStorage {
         }
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let dev = &self.device;
-        let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, src_l)?;
         match (&self.slice, &mut dst.slice) {
             (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
@@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
 ) { \
     const size_t *dims = info; \
     const size_t *strides = info + num_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
             TYPENAME x = inp ? inp[i] : out[i]; \
             out[i] = x * mul + add; \
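Note that on the kernel side the null check has to short-circuit before `is_contiguous`: `dims` and `strides` both point into `info`, so calling `is_contiguous` with a null `info` would dereference a null pointer. A null `info` therefore simply means the host already verified contiguity, and the kernel indexes the buffers linearly. The same reasoning applies to the binary, cast, and unary kernels below.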
@@ -12,8 +12,8 @@ extern "C" __global__ void FN_NAME( \
     const size_t *dims = dims_and_strides; \
     const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
     const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
-    bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \
-    bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
+    bool lhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, lhs_strides); \
+    bool rhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, rhs_strides); \
     if (lhs_cont && rhs_cont) { \
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
             TYPENAME x = lhs[i]; \
@@ -11,7 +11,7 @@ __device__ void cast_(
 ) {
     const size_t *dims = info;
     const size_t *strides = info + num_dims;
-    if (is_contiguous(num_dims, dims, strides)) {
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
             out[i] = inp[i];
         }
@@ -34,7 +34,7 @@ __device__ void cast_through(
 ) {
     const size_t *dims = info;
     const size_t *strides = info + num_dims;
-    if (is_contiguous(num_dims, dims, strides)) {
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
             out[i] = static_cast<T>(static_cast<I>(inp[i]));
         }
@@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
 ) { \
     const size_t *dims = info; \
     const size_t *strides = info + num_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
             TYPENAME x = inp ? inp[i] : out[i]; \
             out[i] = FUNC; \
@@ -71,7 +71,7 @@ extern "C" __global__ void FN_NAME( \
 ) { \
     const size_t *dims = info; \
     const size_t *strides = info + num_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
             TYPENAME x = inp ? inp[i] : out[i]; \
             out[i] = FUNC; \