mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Sketch a fast cuda kernel for reduce-sum. (#109)
* Sketch a fast cuda kernel for reduce-sum. * Sketch the rust support code for the fast sum kernel. * More work on the fast kernel. * Add some testing ground. * A couple fixes for the fast sum kernel.
This commit is contained in:
15
candle-core/examples/cuda_basics.rs
Normal file
15
candle-core/examples/cuda_basics.rs
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle::{Device, Tensor};
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
let device = Device::new_cuda(0)?;
|
||||||
|
let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
|
||||||
|
let sum = t.sum(&[0])?;
|
||||||
|
println!("{sum}");
|
||||||
|
let sum = t.sum(&[1])?;
|
||||||
|
println!("{sum}");
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -357,6 +357,7 @@ impl Map1 for Affine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
struct Sum<'a>(&'a [usize]);
|
struct Sum<'a>(&'a [usize]);
|
||||||
impl<'a> Map1 for Sum<'a> {
|
impl<'a> Map1 for Sum<'a> {
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
@ -393,6 +394,56 @@ impl<'a> Map1 for Sum<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
struct FastSum<'a>(&'a [usize]);
|
||||||
|
impl<'a> Map1 for FastSum<'a> {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let src_stride = layout.stride();
|
||||||
|
let src_dims = layout.shape().dims();
|
||||||
|
let src_el: usize = src_dims.iter().product();
|
||||||
|
// Source dims and strides with the sum dims at the end.
|
||||||
|
let mut dims = vec![];
|
||||||
|
let mut stride = vec![];
|
||||||
|
let mut dst_el: usize = 1;
|
||||||
|
for (dim_idx, &d) in src_dims.iter().enumerate() {
|
||||||
|
if !self.0.contains(&dim_idx) {
|
||||||
|
dst_el *= d;
|
||||||
|
dims.push(d);
|
||||||
|
stride.push(src_stride[dim_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for &dim_idx in self.0.iter() {
|
||||||
|
dims.push(src_dims[dim_idx]);
|
||||||
|
stride.push(src_stride[dim_idx]);
|
||||||
|
}
|
||||||
|
let el_to_sum_per_block = src_el / dst_el;
|
||||||
|
// The reduction loop requires the shared array to be properly initialized and for
|
||||||
|
// this we want the number of threads to be a power of two.
|
||||||
|
let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
|
||||||
|
let cfg = LaunchConfig {
|
||||||
|
// TODO: Maybe use grid_y if the output is too large?
|
||||||
|
// TODO: Specialized implementation when reducing on no or all dimensions or when
|
||||||
|
// reducing only aggregate a small number of elements together.
|
||||||
|
grid_dim: (dst_el as u32, 1, 1),
|
||||||
|
block_dim: (block_dim as u32, 1, 1),
|
||||||
|
shared_mem_bytes: 0,
|
||||||
|
};
|
||||||
|
let ds = dev.htod_copy([dims.as_slice(), stride.as_slice()].concat())?;
|
||||||
|
let src = &src.slice(layout.start_offset()..);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>("fast_sum"), kernels::REDUCE)?;
|
||||||
|
let out = dev.alloc_zeros::<T>(dst_el)?;
|
||||||
|
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<U: crate::op::UnaryOp> Map1 for U {
|
impl<U: crate::op::UnaryOp> Map1 for U {
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
&self,
|
&self,
|
||||||
@ -726,7 +777,7 @@ impl CudaStorage {
|
|||||||
|
|
||||||
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
let slice = Sum(sum_dims).map(&self.slice, &device, layout)?;
|
let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,6 +3,67 @@
|
|||||||
#include "cuda_utils.cuh"
|
#include "cuda_utils.cuh"
|
||||||
#include<stdint.h>
|
#include<stdint.h>
|
||||||
|
|
||||||
|
const int BLOCK_SIZE = 1024;
|
||||||
|
|
||||||
|
// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32 but
|
||||||
|
// also expect a f32 output so that this can be used for normalization e.g. in softmax.
|
||||||
|
|
||||||
|
// Fast reduce sum kernel, this assumes that the dimensions to loop over are at
|
||||||
|
// the end, each block is responsible for populating one value in the output array.
|
||||||
|
// There are at most 1024 threads per block.
|
||||||
|
template <typename T>
|
||||||
|
__device__ void fast_sum(
|
||||||
|
const size_t src_numel,
|
||||||
|
const size_t el_to_sum_per_block,
|
||||||
|
const size_t num_dims,
|
||||||
|
const size_t *info,
|
||||||
|
const T *src,
|
||||||
|
T *dst
|
||||||
|
) {
|
||||||
|
const size_t *dims = info;
|
||||||
|
const size_t *strides = info + num_dims;
|
||||||
|
|
||||||
|
__shared__ T shr[BLOCK_SIZE];
|
||||||
|
size_t tid = threadIdx.x;
|
||||||
|
size_t dst_id = blockIdx.x;
|
||||||
|
|
||||||
|
shr[tid] = 0.0;
|
||||||
|
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||||
|
// to (dst_id + 1) * el_to_sum_per_block.
|
||||||
|
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||||
|
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
|
||||||
|
size_t idx = start_idx + tid;
|
||||||
|
|
||||||
|
while (idx < stop_idx) {
|
||||||
|
// TODO: Fast version for the contiguous case.
|
||||||
|
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||||
|
shr[tid] += src[strided_i];
|
||||||
|
idx += blockDim.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parallel reduction, see the slides:
|
||||||
|
// https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
|
||||||
|
// https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
|
||||||
|
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
|
||||||
|
__syncthreads();
|
||||||
|
if (tid < s) shr[tid] += shr[tid + s];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tid == 0) atomicAdd(dst + dst_id, shr[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define FAST_SUM_OP(TYPENAME, FN_NAME) \
|
||||||
|
extern "C" __global__ void FN_NAME( \
|
||||||
|
const size_t src_numel, \
|
||||||
|
const size_t el_to_sum_per_block, \
|
||||||
|
const size_t num_dims, \
|
||||||
|
const size_t *info, \
|
||||||
|
const TYPENAME *src, \
|
||||||
|
TYPENAME *dst \
|
||||||
|
) { \
|
||||||
|
fast_sum(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
|
||||||
|
} \
|
||||||
|
|
||||||
#define SUM_OP(TYPENAME, FN_NAME) \
|
#define SUM_OP(TYPENAME, FN_NAME) \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
const size_t numel, \
|
const size_t numel, \
|
||||||
@ -45,12 +106,18 @@ extern "C" __global__ void FN_NAME( \
|
|||||||
|
|
||||||
#if __CUDA_ARCH__ >= 800
|
#if __CUDA_ARCH__ >= 800
|
||||||
SUM_OP(__nv_bfloat16, sum_bf16)
|
SUM_OP(__nv_bfloat16, sum_bf16)
|
||||||
|
FAST_SUM_OP(__nv_bfloat16, fast_sum_bf16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 530
|
#if __CUDA_ARCH__ >= 530
|
||||||
SUM_OP(__half, sum_f16)
|
SUM_OP(__half, sum_f16)
|
||||||
|
FAST_SUM_OP(__half, fast_sum_f16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
SUM_OP(float, sum_f32)
|
SUM_OP(float, sum_f32)
|
||||||
SUM_OP(double, sum_f64)
|
SUM_OP(double, sum_f64)
|
||||||
SUM_OP(uint32_t, sum_u32)
|
SUM_OP(uint32_t, sum_u32)
|
||||||
|
|
||||||
|
FAST_SUM_OP(float, fast_sum_f32)
|
||||||
|
FAST_SUM_OP(double, fast_sum_f64)
|
||||||
|
FAST_SUM_OP(uint32_t, fast_sum_u32)
|
||||||
|
Reference in New Issue
Block a user