diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 07d354b6..b74137f3 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -515,8 +515,8 @@ impl<'a> Map1 for Sum<'a> {
     }
 }
 
-struct FastSum<'a>(&'a [usize]);
-impl<'a> Map1 for FastSum<'a> {
+struct FastReduce<'a>(&'a [usize], crate::op::ReduceOp);
+impl<'a> Map1 for FastReduce<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src: &CudaSlice<T>,
@@ -557,8 +557,14 @@ impl<'a> Map1 for FastSum<'a> {
             .htod_copy([dims.as_slice(), stride.as_slice()].concat())
             .w()?;
         let src = &src.slice(layout.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("fast_sum"), kernels::REDUCE)?;
-        let out = dev.alloc_zeros::<T>(dst_el).w()?;
+        let name = match self.1 {
+            crate::op::ReduceOp::Sum => "fast_sum",
+            crate::op::ReduceOp::Min => "fast_min",
+            crate::op::ReduceOp::Max => "fast_max",
+        };
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
+        // SAFETY: filled in by the follow up kernel.
+        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
         let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
         // SAFETY: ffi.
         unsafe { func.launch(cfg, params) }.w()?;
@@ -961,15 +967,9 @@ impl BackendStorage for CudaStorage {
         layout: &Layout,
         sum_dims: &[usize],
     ) -> Result<Self> {
-        match op {
-            crate::op::ReduceOp::Sum => {
-                let device = self.device().clone();
-                let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
-                Ok(Self { slice, device })
-            }
-            crate::op::ReduceOp::Min => Err(CudaError::InternalError("TODO: implement min").into()),
-            crate::op::ReduceOp::Max => Err(CudaError::InternalError("TODO: implement max").into()),
-        }
+        let device = self.device().clone();
+        let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?;
+        Ok(Self { slice, device })
     }
 
     fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index 5d9bddee..fe3acc9e 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -1,4 +1,6 @@
 #include "compatibility.cuh"
+#include <stdint.h>
+#include <cmath>
 
 // TODO: This is often used to check that the data is contiguous so that
 // kernels can be easily mapped. However this only returns true for row
@@ -140,6 +142,9 @@ __device__ __forceinline__ double absg(double a) { return fabs(a); }
 __device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); }
 __device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); }
 
+__device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); }
+__device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }
+
 #if __CUDA_ARCH__ >= 530
 __device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
 __device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index afe687bf..34caf12b 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -2,6 +2,7 @@
 // https://people.maths.ox.ac.uk/gilesm/cuda/prac4/reduction.pdf
 #include "cuda_utils.cuh"
 #include <stdint.h>
+#include <cmath>
 
 const int BLOCK_SIZE = 1024;
 
@@ -27,7 +28,7 @@ __device__ void fast_sum(
   size_t tid = threadIdx.x;
   size_t dst_id = blockIdx.x;
 
-  shr[tid] = 0.0;
+  shr[tid] = 0;
   // Elements summed in this block range from dst_id * el_to_sum_per_block
   // to (dst_id + 1) * el_to_sum_per_block.
   size_t start_idx = dst_id * el_to_sum_per_block;
@@ -49,11 +50,113 @@ __device__ void fast_sum(
     if (tid < s) shr[tid] += shr[tid + s];
   }
 
-  if (tid == 0) atomicAdd(dst + dst_id, shr[0]);
+  if (tid == 0) dst[dst_id] = shr[0];
 }
 
-#define FAST_SUM_OP(TYPENAME, FN_NAME) \
-extern "C" __global__ void FN_NAME( \
+template <typename T>
+__device__ void fast_max(
+    const size_t src_numel,
+    const size_t el_to_sum_per_block,
+    const size_t num_dims,
+    const size_t *info,
+    const T *src,
+    T *dst
+) {
+  const size_t *dims = info;
+  const size_t *strides = info + num_dims;
+
+  __shared__ T shr[BLOCK_SIZE];
+  size_t tid = threadIdx.x;
+  size_t dst_id = blockIdx.x;
+
+  shr[tid] = -INFINITY;
+  // Elements summed in this block range from dst_id * el_to_sum_per_block
+  // to (dst_id + 1) * el_to_sum_per_block.
+  size_t start_idx = dst_id * el_to_sum_per_block;
+  size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+  size_t idx = start_idx + tid;
+
+  while (idx < stop_idx) {
+    // TODO: Fast version for the contiguous case.
+    size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+    shr[tid] = maxg(shr[tid], src[strided_i]);
+    idx += blockDim.x;
+  }
+
+  // Parallel reduction, see the slides:
+  // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
+  // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
+  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+    __syncthreads();
+    if (tid < s) shr[tid] = maxg(shr[tid], shr[tid + s]);
+  }
+
+  if (tid == 0) dst[dst_id] = shr[0];
+}
+
+template <typename T>
+__device__ void fast_min(
+    const size_t src_numel,
+    const size_t el_to_sum_per_block,
+    const size_t num_dims,
+    const size_t *info,
+    const T *src,
+    T *dst
+) {
+  const size_t *dims = info;
+  const size_t *strides = info + num_dims;
+
+  __shared__ T shr[BLOCK_SIZE];
+  size_t tid = threadIdx.x;
+  size_t dst_id = blockIdx.x;
+
+  shr[tid] = INFINITY;
+  // Elements summed in this block range from dst_id * el_to_sum_per_block
+  // to (dst_id + 1) * el_to_sum_per_block.
+  size_t start_idx = dst_id * el_to_sum_per_block;
+  size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+  size_t idx = start_idx + tid;
+
+  while (idx < stop_idx) {
+    // TODO: Fast version for the contiguous case.
+    size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+    shr[tid] = ming(shr[tid], src[strided_i]);
+    idx += blockDim.x;
+  }
+
+  // Parallel reduction, see the slides:
+  // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
+  // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
+  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+    __syncthreads();
+    if (tid < s) shr[tid] = ming(shr[tid], shr[tid + s]);
+  }
+
+  if (tid == 0) dst[dst_id] = shr[0];
+}
+
+#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \
+extern "C" __global__ void MIN_NAME( \
+    const size_t src_numel, \
+    const size_t el_to_sum_per_block, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME *src, \
+    TYPENAME *dst \
+) { \
+  fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+} \
+extern "C" __global__ void MAX_NAME( \
+    const size_t src_numel, \
+    const size_t el_to_sum_per_block, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME *src, \
+    TYPENAME *dst \
+) { \
+  fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+} \
+extern "C" __global__ void SUM_NAME( \
     const size_t src_numel, \
     const size_t el_to_sum_per_block, \
     const size_t num_dims, \
@@ -106,18 +209,18 @@ extern "C" __global__ void FN_NAME( \
 
 #if __CUDA_ARCH__ >= 800
 SUM_OP(__nv_bfloat16, sum_bf16)
-FAST_SUM_OP(__nv_bfloat16, fast_sum_bf16)
+FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_sum_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
 SUM_OP(__half, sum_f16)
-FAST_SUM_OP(__half, fast_sum_f16)
+FAST_OP(__half, fast_min_f16, fast_max_f16, fast_sum_f16)
 #endif
 
 SUM_OP(float, sum_f32)
 SUM_OP(double, sum_f64)
 SUM_OP(uint32_t, sum_u32)
-FAST_SUM_OP(float, fast_sum_f32)
-FAST_SUM_OP(double, fast_sum_f64)
-FAST_SUM_OP(uint32_t, fast_sum_u32)
+FAST_OP(float, fast_min_f32, fast_max_f32, fast_sum_f32)
+FAST_OP(double, fast_min_f64, fast_max_f64, fast_sum_f64)
+FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_sum_u32)
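
For reference, fast_sum, fast_min and fast_max above all share one per-block shared-memory tree reduction; only the identity element (0, INFINITY, -INFINITY) and the combine function (+=, ming, maxg) differ. The standalone sketch below is not part of the patch: it specializes the pattern to a contiguous float input (the patched kernels also walk strided layouts via get_strided_index), and the kernel name block_max, the sizes, and the host driver are illustrative assumptions only. It should build with plain nvcc.

// Standalone sketch (not from the patch): per-block shared-memory tree
// reduction computing a max, specialized to a contiguous float input.
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cuda_runtime.h>

const int BLOCK_SIZE = 1024;

// One block produces one output value: the max over its slice of
// `el_per_block` consecutive elements.
__global__ void block_max(size_t numel, size_t el_per_block,
                          const float *src, float *dst) {
  __shared__ float shr[BLOCK_SIZE];
  size_t tid = threadIdx.x;
  size_t dst_id = blockIdx.x;

  // Each thread folds a strided subset of the block's slice into shr[tid];
  // -INFINITY is the identity element for max.
  shr[tid] = -INFINITY;
  size_t start = dst_id * el_per_block;
  size_t stop = start + el_per_block;
  if (stop > numel) stop = numel;
  for (size_t idx = start + tid; idx < stop; idx += blockDim.x) {
    shr[tid] = fmaxf(shr[tid], src[idx]);
  }

  // Tree reduction: halve the number of active threads each step.
  // blockDim.x must be a power of two for every element to be covered.
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    __syncthreads();
    if (tid < s) shr[tid] = fmaxf(shr[tid], shr[tid + s]);
  }
  if (tid == 0) dst[dst_id] = shr[0];
}

int main() {
  const size_t numel = 1 << 20;
  const size_t el_per_block = 1 << 18; // -> 4 output values
  const size_t n_out = (numel + el_per_block - 1) / el_per_block;

  float *h_src = (float *)malloc(numel * sizeof(float));
  for (size_t i = 0; i < numel; ++i) h_src[i] = (float)(i % 1000);

  float *d_src, *d_dst;
  cudaMalloc(&d_src, numel * sizeof(float));
  cudaMalloc(&d_dst, n_out * sizeof(float));
  cudaMemcpy(d_src, h_src, numel * sizeof(float), cudaMemcpyHostToDevice);

  // One block per output value, mirroring the el_to_sum_per_block scheme.
  block_max<<<(unsigned int)n_out, BLOCK_SIZE>>>(numel, el_per_block, d_src, d_dst);

  float *h_dst = (float *)malloc(n_out * sizeof(float));
  cudaMemcpy(h_dst, d_dst, n_out * sizeof(float), cudaMemcpyDeviceToHost);
  for (size_t i = 0; i < n_out; ++i)
    printf("block %zu max = %g\n", i, h_dst[i]); // expect 999 for each block

  cudaFree(d_src);
  cudaFree(d_dst);
  free(h_src);
  free(h_dst);
  return 0;
}

Because thread 0 of every block writes its dst slot unconditionally, the output buffer never needs to be zero-initialized; this is the same reasoning that lets the Rust side of the patch switch from alloc_zeros to a plain alloc once atomicAdd is dropped.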