Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 11:37:11 +00:00
Cuda backend optimization (#1886)
* Attempt at making the kernel faster.
* Also adapt the cast kernels.
* Also apply to binary ops.
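The change avoids a host-to-device copy on every kernel launch: for contiguous tensors the dims/strides metadata is never read, so the backend now passes a null pointer instead of `htod_copy`-ing the layout, and the kernels short-circuit to their contiguous fast path when that pointer is null (the `info == nullptr || is_contiguous(...)` checks in the .cu hunks below).

A minimal, dependency-free Rust sketch of the host-side decision; `Layout`, `params_from_layout`, and the simplified row-major `is_contiguous` test are illustrative stand-ins for candle's actual types, and `Vec<usize>` stands in for cudarc's `CudaSlice<usize>`:

// Decide whether a kernel launch needs a device copy of [dims..., strides...].
// Contiguous layouts get `Null`, which reaches the kernel as a null pointer.
#[derive(Debug)]
enum SlicePtrOrNull {
    Ptr(Vec<usize>), // stand-in for a device buffer (cudarc's CudaSlice<usize>)
    Null,            // no allocation, no host-to-device copy
}

struct Layout {
    dims: Vec<usize>,
    stride: Vec<usize>,
}

impl Layout {
    // Simplified check: strides match the row-major (C-contiguous) ones.
    fn is_contiguous(&self) -> bool {
        let mut expected = 1usize;
        self.dims.iter().zip(self.stride.iter()).rev().all(|(&d, &s)| {
            let ok = s == expected;
            expected *= d;
            ok
        })
    }
}

fn params_from_layout(l: &Layout) -> SlicePtrOrNull {
    if l.is_contiguous() {
        SlicePtrOrNull::Null
    } else {
        SlicePtrOrNull::Ptr([l.dims.clone(), l.stride.clone()].concat())
    }
}

fn main() {
    let contiguous = Layout { dims: vec![2, 3], stride: vec![3, 1] };
    let strided = Layout { dims: vec![2, 3], stride: vec![1, 2] };
    println!("{:?}", params_from_layout(&contiguous)); // Null
    println!("{:?}", params_from_layout(&strided));    // Ptr([2, 3, 1, 2])
}

Kernel-side, a null `info`/`dims_and_strides` pointer is taken as a promise of contiguity, so the strided indexing (and the metadata read itself) is skipped entirely.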
@@ -11,6 +11,31 @@ use cudarc::driver::{
 use half::{bf16, f16};
 use std::sync::{Arc, Mutex};
 
+enum SlicePtrOrNull<T> {
+    Ptr(CudaSlice<T>),
+    Null,
+}
+
+unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
+    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
+        match self {
+            SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(),
+            SlicePtrOrNull::Null => 0usize.as_kernel_param(),
+        }
+    }
+}
+
+impl SlicePtrOrNull<usize> {
+    fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
+        let ds = if l.is_contiguous() {
+            SlicePtrOrNull::Null
+        } else {
+            SlicePtrOrNull::Ptr(dev.htod_copy([l.dims(), l.stride()].concat()).w()?)
+        };
+        Ok(ds)
+    }
+}
+
 /// cudarc related errors
 #[derive(thiserror::Error, Debug)]
 pub enum CudaError {
@@ -564,7 +589,7 @@ impl Map1 for Affine {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
         // SAFETY: Set later by running the kernel.
@@ -596,7 +621,7 @@ impl Map1 for Elu {
         let dims = shape.dims();
         let el = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
@@ -719,7 +744,7 @@ impl Map1 for Powf {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
@@ -852,7 +877,7 @@ impl<U: UnaryOpT> Map1 for U {
         let dims = shape.dims();
         let el_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
@@ -1402,9 +1427,14 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
         let dims = shape.dims();
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let dims_and_strides = dev
-            .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
-            .w()?;
+        let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
+            SlicePtrOrNull::Null
+        } else {
+            SlicePtrOrNull::Ptr(
+                dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
+                    .w()?,
+            )
+        };
         let lhs = &lhs.slice(lhs_l.start_offset()..);
         let rhs = &rhs.slice(rhs_l.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
@@ -1431,9 +1461,14 @@ impl Map2Any for Cmp {
         let dims = shape.dims();
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let dims_and_strides = dev
-            .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
-            .w()?;
+        let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
+            SlicePtrOrNull::Null
+        } else {
+            SlicePtrOrNull::Ptr(
+                dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
+                    .w()?,
+            )
+        };
         let lhs = &lhs.slice(lhs_l.start_offset()..);
         let rhs = &rhs.slice(rhs_l.start_offset()..);
         let name = match self.0 {
@@ -1640,7 +1675,7 @@ impl BackendStorage for CudaStorage {
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let dev = self.device();
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let start_o = layout.start_offset();
         // This returns an i64 rather than a &i64, this is useful to get around some temporary
         // lifetime issue and is safe as long as self.slice does not go out of scope before inp
@@ -2215,7 +2250,7 @@ impl BackendStorage for CudaStorage {
         }
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let dev = &self.device;
-        let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, src_l)?;
         match (&self.slice, &mut dst.slice) {
             (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
@@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
 ) { \
     const size_t *dims = info; \
     const size_t *strides = info + num_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
             TYPENAME x = inp ? inp[i] : out[i]; \
             out[i] = x * mul + add; \
@@ -12,8 +12,8 @@ extern "C" __global__ void FN_NAME( \
     const size_t *dims = dims_and_strides; \
     const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
     const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
-    bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \
-    bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
+    bool lhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, lhs_strides); \
+    bool rhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, rhs_strides); \
     if (lhs_cont && rhs_cont) { \
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
             TYPENAME x = lhs[i]; \
@@ -11,7 +11,7 @@ __device__ void cast_(
 ) {
     const size_t *dims = info;
     const size_t *strides = info + num_dims;
-    if (is_contiguous(num_dims, dims, strides)) {
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
             out[i] = inp[i];
         }
@@ -34,7 +34,7 @@ __device__ void cast_through(
 ) {
     const size_t *dims = info;
     const size_t *strides = info + num_dims;
-    if (is_contiguous(num_dims, dims, strides)) {
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
             out[i] = static_cast<T>(static_cast<I>(inp[i]));
         }
@@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
 ) { \
     const size_t *dims = info; \
     const size_t *strides = info + num_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
             TYPENAME x = inp ? inp[i] : out[i]; \
             out[i] = FUNC; \
@@ -71,7 +71,7 @@ extern "C" __global__ void FN_NAME( \
 ) { \
     const size_t *dims = info; \
     const size_t *strides = info + num_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
         for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
             TYPENAME x = inp ? inp[i] : out[i]; \
             out[i] = FUNC; \