mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00

* Fast kernels for rotary embeddings.
* Add a test for the fast CPU kernel.
* Rope cuda bindings.
* Cuda kernel.
* Metal kernel (part 1).
* Cuda kernels.
* Finish the metal kernel.
* Use the new kernels in the quantized example.
* Fix warning.
#include "cuda_utils.cuh"
|
|
#include <cmath>
|
|
#include <stdint.h>
|
|
|
|
#define WARP_SIZE 32
|
|
const int BLOCK_SIZE = 1024;

// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulates in
// f32 but also expects an f32 output so that this can be used for
// normalization, e.g. in softmax.

// Fast reduce-sum kernel. This assumes that the dimensions to loop over are at
// the end; each block is responsible for populating one value in the output
// array. There are at most 1024 threads per block. See the launch sketch after
// this function.
template <typename T>
__device__ void
fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
         const size_t num_dims, const size_t *info, const T *src, T *dst) {
  const size_t *dims = info;
  const size_t *strides = info + num_dims;

  __shared__ T shr[BLOCK_SIZE];
  size_t tid = threadIdx.x;
  size_t dst_id = blockIdx.x;

  shr[tid] = 0;
  // Elements summed in this block range from dst_id * el_to_sum_per_block
  // to (dst_id + 1) * el_to_sum_per_block.
  size_t start_idx = dst_id * el_to_sum_per_block;
  size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
  size_t idx = start_idx + tid;

  while (idx < stop_idx) {
    // TODO: Fast version for the contiguous case.
    size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
    shr[tid] += src[strided_i];
    idx += blockDim.x;
  }

  // Parallel reduction, see the slides:
  // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
  // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    __syncthreads();
    if (tid < s)
      shr[tid] += shr[tid + s];
  }

  if (tid == 0)
    dst[dst_id] = shr[0];
}
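
// Example (hypothetical, not part of the kernel set): a host-side launch sketch
// for the `fast_sum_f32` entry point generated by FAST_OP at the end of this
// file, reducing a contiguous [num_rows, row_len] f32 tensor over its last
// dimension. One block per output value, BLOCK_SIZE threads per block, `info`
// holding the dims followed by the strides. All names below are illustrative.
//
//   size_t info[] = {num_rows, row_len, /* strides */ row_len, 1};
//   // ... copy `info` to the device as `info_dev` ...
//   fast_sum_f32<<<num_rows, BLOCK_SIZE>>>(num_rows * row_len, row_len,
//                                          /* num_dims */ 2, info_dev,
//                                          src_dev, dst_dev);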

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1) {
    x += __shfl_xor_sync(0xffffffff, x, mask, 32);
  }
  return x;
}
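
// Butterfly reduction: each __shfl_xor_sync step adds the value held by the
// lane whose index differs in one bit (lane ^ 16, then ^ 8, ^ 4, ^ 2, ^ 1), so
// after five steps every lane of the warp holds the sum over all 32 lanes.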

// RmsNorm implementation adapted from ggml; accumulation is done in f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) {
  const int row = blockIdx.x*blockDim.y + threadIdx.y;
  const int tid = threadIdx.x;
  const int block_size = blockDim.x;

  float tmp = 0.0f; // partial sum for thread in warp

  for (int col = tid; col < ncols; col += block_size) {
    const float xi = static_cast<float>(x[row*ncols + col]);
    tmp += xi * xi;
  }

  // sum up partial sums
  tmp = warp_reduce_sum(tmp);
  if (block_size > WARP_SIZE) {
    __shared__ float s_sum[32];
    int warp_id = threadIdx.x / WARP_SIZE;
    int lane_id = threadIdx.x % WARP_SIZE;
    if (lane_id == 0) {
      s_sum[warp_id] = tmp;
    }
    __syncthreads();
    tmp = s_sum[lane_id];
    tmp = warp_reduce_sum(tmp);
  }

  const float mean = tmp / ncols;
  const float scale = rsqrtf(mean + eps);

  if (alpha == nullptr) {
    for (int col = tid; col < ncols; col += block_size) {
      dst[row*ncols + col] = static_cast<T>(scale * static_cast<float>(x[row*ncols + col]));
    }
  }
  else {
    for (int col = tid; col < ncols; col += block_size) {
      float a = static_cast<float>(alpha[col]);
      dst[row*ncols + col] = static_cast<T>(scale * static_cast<float>(x[row*ncols + col]) * a);
    }
  }
}
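
// In other words dst[row] = x[row] * alpha * rsqrt(mean(x[row]^2) + eps), with
// alpha treated as all-ones when it is a null pointer. A hypothetical launch
// with one warp per row (`nrows`, `src_dev`, `dst_dev`, `alpha_dev` are
// illustrative names):
//
//   dim3 block(WARP_SIZE, 1);
//   dim3 grid(nrows);
//   rmsnorm_f32<<<grid, block>>>(src_dev, dst_dev, alpha_dev, ncols, eps);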

// Softmax implementation adapted from ggml.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159
template <typename T, typename ACC>
__device__ void softmax(const T * x, T * dst, const int ncols) {
  const int row = blockDim.x*blockIdx.x + threadIdx.x;
  const int block_size = blockDim.y;
  const int tid = threadIdx.y;

  T max_val = -INFINITY;

  for (int col = tid; col < ncols; col += block_size) {
    const int i = row*ncols + col;
    max_val = maxg(max_val, x[i]);
  }

  // find the max value in the block
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1) {
    max_val = maxg(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
  }

  ACC tmp = 0.;

  for (int col = tid; col < ncols; col += block_size) {
    const int i = row*ncols + col;
    const T val = expg(x[i] - max_val);
    tmp += static_cast<ACC>(val);
    dst[i] = val;
  }

  // sum up partial sums
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1) {
    tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  }

  const ACC inv_tmp = 1. / tmp;

  for (int col = tid; col < ncols; col += block_size) {
    const int i = row*ncols + col;
    dst[i] *= inv_tmp;
  }
}
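
// Rows are indexed along x and the column loop runs along y; the shuffles above
// reduce over a single warp, so block_size (blockDim.y) is expected to be at
// most WARP_SIZE. A hypothetical launch with one warp per row (`nrows`,
// `src_dev`, `dst_dev` are illustrative names):
//
//   dim3 block(1, WARP_SIZE);
//   dim3 grid(nrows);
//   softmax_f32<<<grid, block>>>(src_dev, dst_dev, ncols);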

template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Each thread rotates one (even, odd) pair; drop the threads that fall past
  // the last pair to avoid out-of-bounds accesses.
  if (2 * idx >= bh * td) return;

  uint32_t rope_idx = idx % (td / 2);
  T c = cos[rope_idx];
  T s = sin[rope_idx];

  dst[2 * idx] = src[2 * idx] * c - src[2 * idx + 1] * s;
  dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c;
}
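
// Interleaved rotary embedding: every consecutive (x0, x1) pair is rotated to
// (x0 * cos - x1 * sin, x0 * sin + x1 * cos), cos/sin being indexed by the
// pair position within each bh slice (idx % (td / 2)). Hypothetical launch
// covering all bh * td / 2 pairs with 256 threads per block (numbers and
// device pointer names are illustrative):
//
//   const uint32_t n_pairs = bh * td / 2;
//   rope_i_f32<<<(n_pairs + 255) / 256, 256>>>(src_dev, cos_dev, sin_dev,
//                                              dst_dev, bh, td);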

template <typename T>
__device__ void
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
         const size_t num_dims, const size_t *info, const T *src, T *dst) {
  const size_t *dims = info;
  const size_t *strides = info + num_dims;

  __shared__ T shr[BLOCK_SIZE];
  size_t tid = threadIdx.x;
  size_t dst_id = blockIdx.x;

  shr[tid] = -INFINITY;
  // Elements reduced in this block range from dst_id * el_to_sum_per_block
  // to (dst_id + 1) * el_to_sum_per_block.
  size_t start_idx = dst_id * el_to_sum_per_block;
  size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
  size_t idx = start_idx + tid;

  while (idx < stop_idx) {
    // TODO: Fast version for the contiguous case.
    size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
    shr[tid] = maxg(shr[tid], src[strided_i]);
    idx += blockDim.x;
  }

  // Parallel reduction, see the slides:
  // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
  // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    __syncthreads();
    if (tid < s)
      shr[tid] = maxg(shr[tid], shr[tid + s]);
  }

  if (tid == 0)
    dst[dst_id] = shr[0];
}

template <typename T>
__device__ void
fast_min(const size_t src_numel, const size_t el_to_sum_per_block,
         const size_t num_dims, const size_t *info, const T *src, T *dst) {
  const size_t *dims = info;
  const size_t *strides = info + num_dims;

  __shared__ T shr[BLOCK_SIZE];
  size_t tid = threadIdx.x;
  size_t dst_id = blockIdx.x;

  shr[tid] = INFINITY;
  // Elements reduced in this block range from dst_id * el_to_sum_per_block
  // to (dst_id + 1) * el_to_sum_per_block.
  size_t start_idx = dst_id * el_to_sum_per_block;
  size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
  size_t idx = start_idx + tid;

  while (idx < stop_idx) {
    // TODO: Fast version for the contiguous case.
    size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
    shr[tid] = ming(shr[tid], src[strided_i]);
    idx += blockDim.x;
  }

  // Parallel reduction, see the slides:
  // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
  // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    __syncthreads();
    if (tid < s)
      shr[tid] = ming(shr[tid], shr[tid + s]);
  }

  if (tid == 0)
    dst[dst_id] = shr[0];
}

template <typename T>
__device__ void
fast_argmin(const size_t src_numel, const size_t el_to_sum_per_block,
            const size_t num_dims, const size_t *info, const T *src, uint32_t *dst) {
  const size_t *dims = info;
  const size_t *strides = info + num_dims;

  __shared__ T shr[BLOCK_SIZE];
  __shared__ uint32_t shr_index[BLOCK_SIZE];
  size_t tid = threadIdx.x;
  size_t dst_id = blockIdx.x;

  // Not sure how this works for uint32_t and uint8_t, but it seems to do ok.
  shr[tid] = INFINITY;
  shr_index[tid] = 0xFFFFFFFF;
  bool not_set = true;
  // Elements reduced in this block range from dst_id * el_to_sum_per_block
  // to (dst_id + 1) * el_to_sum_per_block.
  size_t start_idx = dst_id * el_to_sum_per_block;
  size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
  size_t idx = start_idx + tid;

  while (idx < stop_idx) {
    // TODO: Fast version for the contiguous case.
    size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
    if (not_set || src[strided_i] < shr[tid]) {
      shr[tid] = src[strided_i];
      // Assume that the reduction takes place over the last dimension which is contiguous.
      shr_index[tid] = idx % dims[num_dims - 1];
      not_set = false;
    }
    idx += blockDim.x;
  }

  // Parallel reduction, see the slides:
  // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
  // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    __syncthreads();
    if (tid < s && shr[tid + s] < shr[tid]) {
      shr[tid] = shr[tid + s];
      shr_index[tid] = shr_index[tid + s];
    }
  }

  if (tid == 0)
    dst[dst_id] = shr_index[0];
}
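
// For example, taking the argmin of a (4, 128) tensor over its last dimension:
// block dst_id scans idx in [dst_id * 128, (dst_id + 1) * 128), and
// idx % dims[1] == idx % 128 is the column of the current minimum, which is
// what ends up in dst[dst_id].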

template <typename T>
__device__ void
fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
            const size_t num_dims, const size_t *info, const T *src, uint32_t *dst) {
  const size_t *dims = info;
  const size_t *strides = info + num_dims;

  __shared__ T shr[BLOCK_SIZE];
  __shared__ uint32_t shr_index[BLOCK_SIZE];
  size_t tid = threadIdx.x;
  size_t dst_id = blockIdx.x;

  shr[tid] = -INFINITY;
  shr_index[tid] = 0xFFFFFFFF;
  bool not_set = true;
  // Elements reduced in this block range from dst_id * el_to_sum_per_block
  // to (dst_id + 1) * el_to_sum_per_block.
  size_t start_idx = dst_id * el_to_sum_per_block;
  size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
  size_t idx = start_idx + tid;

  while (idx < stop_idx) {
    // TODO: Fast version for the contiguous case.
    size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
    if (not_set || src[strided_i] > shr[tid]) {
      shr[tid] = src[strided_i];
      // Assume that the reduction takes place over the last dimension which is contiguous.
      shr_index[tid] = idx % dims[num_dims - 1];
      not_set = false;
    }
    idx += blockDim.x;
  }

  // Parallel reduction, see the slides:
  // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
  // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    __syncthreads();
    if (tid < s && shr[tid + s] > shr[tid]) {
      shr[tid] = shr[tid + s];
      shr_index[tid] = shr_index[tid + s];
    }
  }

  if (tid == 0)
    dst[dst_id] = shr_index[0];
}

#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, ARGMIN_NAME, ARGMAX_NAME, SUM_NAME) \
extern "C" __global__ void ARGMIN_NAME( \
    const size_t src_numel, const size_t el_to_sum_per_block, \
    const size_t num_dims, const size_t *info, const TYPENAME *src, \
    uint32_t *dst) { \
  fast_argmin(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
} \
extern "C" __global__ void ARGMAX_NAME( \
    const size_t src_numel, const size_t el_to_sum_per_block, \
    const size_t num_dims, const size_t *info, const TYPENAME *src, \
    uint32_t *dst) { \
  fast_argmax(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
} \
extern "C" __global__ void MIN_NAME( \
    const size_t src_numel, const size_t el_to_sum_per_block, \
    const size_t num_dims, const size_t *info, const TYPENAME *src, \
    TYPENAME *dst) { \
  fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
} \
extern "C" __global__ void MAX_NAME( \
    const size_t src_numel, const size_t el_to_sum_per_block, \
    const size_t num_dims, const size_t *info, const TYPENAME *src, \
    TYPENAME *dst) { \
  fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
} \
extern "C" __global__ void SUM_NAME( \
    const size_t src_numel, const size_t el_to_sum_per_block, \
    const size_t num_dims, const size_t *info, const TYPENAME *src, \
    TYPENAME *dst) { \
  fast_sum(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
}
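
// For instance, FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32,
// fast_argmax_f32, fast_sum_f32) at the bottom of this file emits five
// extern "C" entry points sharing the
// (src_numel, el_to_sum_per_block, num_dims, info, src, dst) signature, the
// argmin/argmax variants writing uint32_t indices instead of TYPENAME values.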

#define SUM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, const size_t num_dims, const size_t num_sum_dims, \
    const size_t *info, const TYPENAME *inp, TYPENAME *out) { \
  const size_t *dims = info; \
  const size_t *strides = info + num_dims; \
  const size_t *sum_dims_l = info + 2 * num_dims; \
  const size_t *sum_dims_s = info + 2 * num_dims + num_sum_dims; \
  if (is_contiguous(num_dims, dims, strides)) { \
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \
         i += blockDim.x * gridDim.x) { \
      size_t dst_index = i; \
      for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
        size_t stride = sum_dims_s[nd]; \
        size_t pre = dst_index / stride; \
        size_t post = dst_index % stride; \
        dst_index = (pre / sum_dims_l[nd]) * stride + post; \
      } \
      atomicAdd(out + dst_index, inp[i]); \
    } \
  } else { \
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \
         i += blockDim.x * gridDim.x) { \
      unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
      size_t dst_index = i; \
      for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
        size_t stride = sum_dims_s[nd]; \
        size_t pre = dst_index / stride; \
        size_t post = dst_index % stride; \
        dst_index = (pre / sum_dims_l[nd]) * stride + post; \
      } \
      atomicAdd(out + dst_index, inp[strided_i]); \
    } \
  } \
}
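
// Worked example for the dst_index computation above: dims = (2, 3, 4) and the
// sum runs over dim 1, so sum_dims_l = {3} and sum_dims_s = {4}. For
// i = a*12 + b*4 + c we get pre = a*3 + b, post = c and
// dst_index = ((a*3 + b) / 3) * 4 + c = a*4 + c, i.e. the flattened index into
// the (2, 1, 4) output with the summed dimension collapsed.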

#define SOFTMAX_OP(TYPENAME, ACC_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const TYPENAME *src, TYPENAME *dst, \
    const int n_cols) { \
  softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
}

#define RMSNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
    const int n_cols, const float eps) { \
  rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
}

#define ROPEI_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const TYPENAME *src, \
    const TYPENAME *cos, \
    const TYPENAME *sin, \
    TYPENAME *dst, \
    const uint32_t bh, \
    const uint32_t td) { \
  ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
}

#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
ROPEI_OP(__nv_bfloat16, rope_i_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif

#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
ROPEI_OP(__half, rope_i_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif

SUM_OP(float, sum_f32)
SUM_OP(double, sum_f64)
SUM_OP(uint32_t, sum_u32)
SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
ROPEI_OP(float, rope_i_f32)
ROPEI_OP(double, rope_i_f64)

FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32)
FAST_OP(int64_t, fast_min_i64, fast_max_i64, fast_argmin_i64, fast_argmax_i64, fast_sum_i64)
FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8)