diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 271502c5..fca6865e 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -49,6 +49,50 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
   dst[dst_id] = shr[0];
 }
 
+// Softmax implementation adapted from ggml.
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159
+template <typename T, typename ACC>
+__device__ void softmax(const T * x, T * dst, const int ncols) {
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int block_size = blockDim.y;
+    const int tid = threadIdx.y;
+
+    T max_val = -INFINITY;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const int i = row*ncols + col;
+        max_val = maxg(max_val, x[i]);
+    }
+
+    // find the max value in the block
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        max_val = maxg(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    }
+
+    ACC tmp = 0.;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const int i = row*ncols + col;
+        const T val = expg(x[i] - max_val);
+        tmp += static_cast<ACC>(val);
+        dst[i] = val;
+    }
+
+    // sum up partial sums
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    const ACC inv_tmp = 1. / tmp;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const int i = row*ncols + col;
+        dst[i] *= inv_tmp;
+    }
+}
+
 template <typename T>
 __device__ void
 fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
@@ -290,12 +334,21 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
   } \
 }
 
+#define SOFTMAX_OP(TYPENAME, ACC_TYPENAME, FN_NAME) \
+  extern "C" __global__ void FN_NAME( \
+      const TYPENAME *src, TYPENAME *dst, \
+      const int n_cols) { \
+    softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
+  } \
+
 #if __CUDA_ARCH__ >= 800
+SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
 SUM_OP(__nv_bfloat16, sum_bf16)
 FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
+SOFTMAX_OP(__half, float, softmax_f16)
 SUM_OP(__half, sum_f16)
 FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
 #endif
@@ -303,6 +356,8 @@ FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fa
 SUM_OP(float, sum_f32)
 SUM_OP(double, sum_f64)
 SUM_OP(uint32_t, sum_u32)
+SOFTMAX_OP(float, float, softmax_f32)
+SOFTMAX_OP(double, double, softmax_f64)
 FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
 FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
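
Notes on the new kernel (commentary, not part of the diff). The row-wise reductions use __shfl_xor_sync with a fixed width of 32 and a first mask of 16, so the kernel assumes that the threads cooperating on a row form a single warp (blockDim.y at most 32); with more threads per row, partial maxima and sums would never be combined across warps. The sketch below is a hypothetical host-side harness built on that assumption, launching softmax_f32 with one block per row and 32 threads along y. The file name, launch configuration, and build line are illustrative assumptions for this sketch, not candle's actual host code.

// Hypothetical standalone harness (illustration only, not part of candle): runs the
// softmax_f32 entry point added above on a small matrix and checks that every row
// of the output sums to about 1.0. Assumed build line, device-linking against the
// kernel file:
//   nvcc -rdc=true softmax_harness.cu reduce.cu -o softmax_harness
// (plus whatever include and -arch flags reduce.cu itself needs).
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Defined in reduce.cu by SOFTMAX_OP(float, float, softmax_f32).
extern "C" __global__ void softmax_f32(const float *src, float *dst, const int n_cols);

int main() {
    const int n_rows = 4, n_cols = 100;
    std::vector<float> h_src(n_rows * n_cols);
    for (int i = 0; i < n_rows * n_cols; ++i) h_src[i] = (float)(i % 7) - 3.0f;

    float *d_src = nullptr, *d_dst = nullptr;
    cudaMalloc(&d_src, n_rows * n_cols * sizeof(float));
    cudaMalloc(&d_dst, n_rows * n_cols * sizeof(float));
    cudaMemcpy(d_src, h_src.data(), n_rows * n_cols * sizeof(float), cudaMemcpyHostToDevice);

    // Assumed launch shape: one block per row (blockIdx.x picks the row, threadIdx.x
    // stays 0) and blockDim.y == 32 so the 32 cooperating threads are exactly one
    // warp, matching the fixed width of the __shfl_xor_sync reductions.
    dim3 grid(n_rows, 1, 1);
    dim3 block(1, 32, 1);
    softmax_f32<<<grid, block>>>(d_src, d_dst, n_cols);
    cudaDeviceSynchronize();

    std::vector<float> h_dst(n_rows * n_cols);
    cudaMemcpy(h_dst.data(), d_dst, n_rows * n_cols * sizeof(float), cudaMemcpyDeviceToHost);
    for (int r = 0; r < n_rows; ++r) {
        float sum = 0.0f;
        for (int c = 0; c < n_cols; ++c) sum += h_dst[r * n_cols + c];
        printf("row %d sums to %f\n", r, sum); // expect approximately 1.0
    }

    cudaFree(d_src);
    cudaFree(d_dst);
    return 0;
}

Because each thread strides over the columns (col starts at tid and advances by block_size), n_cols does not need to be a multiple of 32; lanes that receive no columns still contribute the identities -INFINITY and 0 to the max and sum reductions, so the result stays correct.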