mirror of https://github.com/huggingface/candle.git
Cuda kernels for fast min/max reductions (#203)
* Add the min/max cuda kernels.
* Better integration of the cuda kernels.
@@ -515,8 +515,8 @@ impl<'a> Map1 for Sum<'a> {
     }
 }
 
-struct FastSum<'a>(&'a [usize]);
-impl<'a> Map1 for FastSum<'a> {
+struct FastReduce<'a>(&'a [usize], crate::op::ReduceOp);
+impl<'a> Map1 for FastReduce<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src: &CudaSlice<T>,
@@ -557,8 +557,14 @@ impl<'a> Map1 for FastSum<'a> {
             .htod_copy([dims.as_slice(), stride.as_slice()].concat())
             .w()?;
         let src = &src.slice(layout.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("fast_sum"), kernels::REDUCE)?;
-        let out = dev.alloc_zeros::<T>(dst_el).w()?;
+        let name = match self.1 {
+            crate::op::ReduceOp::Sum => "fast_sum",
+            crate::op::ReduceOp::Min => "fast_min",
+            crate::op::ReduceOp::Max => "fast_max",
+        };
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
+        // SAFETY: filled in by the follow up kernel.
+        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
         let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
         // SAFETY: ffi.
         unsafe { func.launch(cfg, params) }.w()?;
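The dispatch above works because the host launch stays identical for all three ops: the kernels generated below share one parameter list, so only the name passed to get_or_load_func changes, and the output buffer can switch from alloc_zeros to a plain (uninitialized) alloc since each block unconditionally writes its own dst element. As a minimal sketch of the shared entry-point shape (f32 shown, bodies elided; the real definitions come from the FAST_OP macro further down):

extern "C" __global__ void fast_sum_f32(
    const size_t src_numel,           // total number of input elements
    const size_t el_to_sum_per_block, // inputs reduced into each output
    const size_t num_dims,            // rank of the input tensor
    const size_t *info,               // packed [dims..., strides...]
    const float *src,
    float *dst);
// fast_min_f32 and fast_max_f32 take exactly the same arguments.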
@@ -961,15 +967,9 @@ impl BackendStorage for CudaStorage {
         layout: &Layout,
         sum_dims: &[usize],
     ) -> Result<Self> {
-        match op {
-            crate::op::ReduceOp::Sum => {
-                let device = self.device().clone();
-                let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
-                Ok(Self { slice, device })
-            }
-            crate::op::ReduceOp::Min => Err(CudaError::InternalError("TODO: implement min").into()),
-            crate::op::ReduceOp::Max => Err(CudaError::InternalError("TODO: implement max").into()),
-        }
+        let device = self.device().clone();
+        let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?;
+        Ok(Self { slice, device })
     }
 
     fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
@@ -1,4 +1,6 @@
 #include "compatibility.cuh"
+#include<stdint.h>
+#include<cmath>
 
 // TODO: This is often used to check that the data is contiguous so that
 // kernels can be easily mapped. However this only returns true for row
@@ -140,6 +142,9 @@ __device__ __forceinline__ double absg(double a) { return fabs(a); }
 __device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); }
 __device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); }
 
+__device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); }
+__device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }
+
 #if __CUDA_ARCH__ >= 530
 __device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
 __device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
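These uint32_t overloads give the templated kernels in reduce.cu a ming/maxg they can call for u32 as well: the kernels use one generic name and overload resolution picks the right device intrinsic per element type. A standalone sketch of that pattern (hypothetical demo_* names, not part of cuda_utils.cuh):

// One generic combiner; the per-type intrinsic is resolved at compile time.
__device__ __forceinline__ float demo_maxg(float a, float b) { return fmaxf(a, b); }
__device__ __forceinline__ uint32_t demo_maxg(uint32_t a, uint32_t b) { return max(a, b); }

template <typename T>
__device__ T demo_combine(T a, T b) {
    return demo_maxg(a, b); // fmaxf for float, integer max for uint32_t
}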
@@ -2,6 +2,7 @@
 // https://people.maths.ox.ac.uk/gilesm/cuda/prac4/reduction.pdf
 #include "cuda_utils.cuh"
 #include<stdint.h>
+#include<cmath>
 
 const int BLOCK_SIZE = 1024;
 
@@ -27,7 +28,7 @@ __device__ void fast_sum(
     size_t tid = threadIdx.x;
     size_t dst_id = blockIdx.x;
 
-    shr[tid] = 0.0;
+    shr[tid] = 0;
     // Elements summed in this block range from dst_id * el_to_sum_per_block
     // to (dst_id + 1) * el_to_sum_per_block.
     size_t start_idx = dst_id * el_to_sum_per_block;
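For context: the loop that follows this initialization (visible in the next hunk's context lines) has each thread fold its block's strided slice of the input into shr[tid] before the shared-memory reduction begins, so el_to_sum_per_block may exceed the thread count without extra logic. A contiguous-case sketch of that per-thread phase (hypothetical helper; the real kernels map idx through get_strided_index):

// Accumulate a private partial sum over [start_idx, stop_idx);
// thread t touches elements start_idx + t, start_idx + t + blockDim.x, ...
__device__ float thread_partial_sum(const float *src, size_t start_idx,
                                    size_t stop_idx) {
    float acc = 0.f;
    for (size_t idx = start_idx + threadIdx.x; idx < stop_idx; idx += blockDim.x) {
        acc += src[idx];
    }
    return acc;
}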
@@ -49,11 +50,113 @@ __device__ void fast_sum(
         if (tid < s) shr[tid] += shr[tid + s];
     }
 
-    if (tid == 0) atomicAdd(dst + dst_id, shr[0]);
+    if (tid == 0) dst[dst_id] = shr[0];
 }
 
-#define FAST_SUM_OP(TYPENAME, FN_NAME) \
-extern "C" __global__ void FN_NAME( \
+template <typename T>
+__device__ void fast_max(
+    const size_t src_numel,
+    const size_t el_to_sum_per_block,
+    const size_t num_dims,
+    const size_t *info,
+    const T *src,
+    T *dst
+) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
+
+    __shared__ T shr[BLOCK_SIZE];
+    size_t tid = threadIdx.x;
+    size_t dst_id = blockIdx.x;
+
+    shr[tid] = -INFINITY;
+    // Elements summed in this block range from dst_id * el_to_sum_per_block
+    // to (dst_id + 1) * el_to_sum_per_block.
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+    size_t idx = start_idx + tid;
+
+    while (idx < stop_idx) {
+        // TODO: Fast version for the contiguous case.
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+        shr[tid] = maxg(shr[tid], src[strided_i]);
+        idx += blockDim.x;
+    }
+
+    // Parallel reduction, see the slides:
+    // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
+    // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        __syncthreads();
+        if (tid < s) shr[tid] = maxg(shr[tid], shr[tid + s]);
+    }
+
+    if (tid == 0) dst[dst_id] = shr[0];
+}
+
+template <typename T>
+__device__ void fast_min(
+    const size_t src_numel,
+    const size_t el_to_sum_per_block,
+    const size_t num_dims,
+    const size_t *info,
+    const T *src,
+    T *dst
+) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
+
+    __shared__ T shr[BLOCK_SIZE];
+    size_t tid = threadIdx.x;
+    size_t dst_id = blockIdx.x;
+
+    shr[tid] = INFINITY;
+    // Elements summed in this block range from dst_id * el_to_sum_per_block
+    // to (dst_id + 1) * el_to_sum_per_block.
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+    size_t idx = start_idx + tid;
+
+    while (idx < stop_idx) {
+        // TODO: Fast version for the contiguous case.
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+        shr[tid] = ming(shr[tid], src[strided_i]);
+        idx += blockDim.x;
+    }
+
+    // Parallel reduction, see the slides:
+    // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
+    // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        __syncthreads();
+        if (tid < s) shr[tid] = ming(shr[tid], shr[tid + s]);
+    }
+
+    if (tid == 0) dst[dst_id] = shr[0];
+}
+
+#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \
+extern "C" __global__ void MIN_NAME( \
+    const size_t src_numel, \
+    const size_t el_to_sum_per_block, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME *src, \
+    TYPENAME *dst \
+) { \
+    fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+} \
+extern "C" __global__ void MAX_NAME( \
+    const size_t src_numel, \
+    const size_t el_to_sum_per_block, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME *src, \
+    TYPENAME *dst \
+) { \
+    fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+} \
+extern "C" __global__ void SUM_NAME( \
     const size_t src_numel, \
     const size_t el_to_sum_per_block, \
     const size_t num_dims, \
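fast_min and fast_max are line-for-line mirrors of fast_sum: only the identity value loaded into shared memory changes (INFINITY / -INFINITY instead of 0, so any real element replaces it) and the combining function (ming / maxg instead of +=). Because exactly one block owns each output element, the final write is a plain store, which is also what lets fast_sum drop its atomicAdd above. A standalone sketch of the shared tree-reduction step, specialized to max (assumes a power-of-two blockDim, which BLOCK_SIZE = 1024 satisfies):

// shr holds one partial result per thread on entry; shr[0] holds the
// block maximum after log2(blockDim.x) halving rounds.
__device__ float block_max(float *shr, size_t tid) {
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        __syncthreads(); // make the previous round's writes visible
        if (tid < s) shr[tid] = fmaxf(shr[tid], shr[tid + s]);
    }
    __syncthreads();     // so every thread may then read the result
    return shr[0];
}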
@@ -106,18 +209,18 @@ extern "C" __global__ void FN_NAME( \
 
 #if __CUDA_ARCH__ >= 800
 SUM_OP(__nv_bfloat16, sum_bf16)
-FAST_SUM_OP(__nv_bfloat16, fast_sum_bf16)
+FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_sum_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
 SUM_OP(__half, sum_f16)
-FAST_SUM_OP(__half, fast_sum_f16)
+FAST_OP(__half, fast_min_f16, fast_max_f16, fast_sum_f16)
 #endif
 
 SUM_OP(float, sum_f32)
 SUM_OP(double, sum_f64)
 SUM_OP(uint32_t, sum_u32)
 
-FAST_SUM_OP(float, fast_sum_f32)
-FAST_SUM_OP(double, fast_sum_f64)
-FAST_SUM_OP(uint32_t, fast_sum_u32)
+FAST_OP(float, fast_min_f32, fast_max_f32, fast_sum_f32)
+FAST_OP(double, fast_min_f64, fast_max_f64, fast_sum_f64)
+FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_sum_u32)
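For reference, a hypothetical host-side launch of one of these entry points in plain CUDA C++ (candle itself drives them through the Rust backend shown above, not like this): reduce a contiguous 4 x 1024 f32 tensor over its last dimension, one block per output element. Everything except the fast_max_f32 signature is illustration only, and the kernel is assumed to be linked into the binary (e.g. compiled with nvcc -rdc=true).

#include <cuda_runtime.h>
#include <cstdint>

extern "C" __global__ void fast_max_f32(
    const size_t src_numel, const size_t el_to_sum_per_block,
    const size_t num_dims, const size_t *info,
    const float *src, float *dst);

int main() {
    const size_t dst_el = 4, src_numel = 4 * 1024;
    const size_t el_per_block = src_numel / dst_el;  // 1024 inputs per output
    const size_t info_h[4] = {4, 1024, 1024, 1};     // dims..., then strides...

    float *src_d, *dst_d; size_t *info_d;
    cudaMalloc(&src_d, src_numel * sizeof(float));
    cudaMalloc(&dst_d, dst_el * sizeof(float));
    cudaMalloc(&info_d, sizeof(info_h));
    cudaMemcpy(info_d, info_h, sizeof(info_h), cudaMemcpyHostToDevice);
    // ... fill src_d with input data ...

    // One block per output element, BLOCK_SIZE (1024) threads per block.
    fast_max_f32<<<dst_el, 1024>>>(src_numel, el_per_block, 2, info_d, src_d, dst_d);
    cudaDeviceSynchronize();
    return 0;
}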