mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Fast kernels for rotary embeddings. (#1928)
* Fast kernels for rotary embeddings. * Add a test for the fast CPU kernel. * Rope cuda bindings. * Cuda kernel. * Metal kernel (part 1). * Cuda kernels. * Finish the metal kernel. * Use the new kernels in the quantized example. * Fix warning.
This commit is contained in:
@ -147,6 +147,20 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (2 * idx > bh * td) return;
|
||||
|
||||
uint32_t rope_idx = idx % (td / 2);
|
||||
T c = cos[rope_idx];
|
||||
T s = sin[rope_idx];
|
||||
|
||||
dst[2 * idx] = src[2 * idx] * c - src[2 * idx + 1] * s;
|
||||
dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
__device__ void
|
||||
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
@ -402,9 +416,21 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
|
||||
} \
|
||||
|
||||
#define ROPEI_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const TYPENAME *src, \
|
||||
const TYPENAME *cos, \
|
||||
const TYPENAME *sin, \
|
||||
TYPENAME *dst, \
|
||||
const uint32_t bh, \
|
||||
const uint32_t td) { \
|
||||
ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
|
||||
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
|
||||
ROPEI_OP(__nv_bfloat16, rope_i_bf16)
|
||||
SUM_OP(__nv_bfloat16, sum_bf16)
|
||||
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
|
||||
#endif
|
||||
@ -412,6 +438,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
SOFTMAX_OP(__half, float, softmax_f16)
|
||||
RMSNORM_OP(__half, rmsnorm_f16)
|
||||
ROPEI_OP(__half, rope_i_f16)
|
||||
SUM_OP(__half, sum_f16)
|
||||
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
|
||||
#endif
|
||||
@ -423,6 +450,8 @@ SOFTMAX_OP(float, float, softmax_f32)
|
||||
SOFTMAX_OP(double, double, softmax_f64)
|
||||
RMSNORM_OP(float, rmsnorm_f32)
|
||||
RMSNORM_OP(double, rmsnorm_f64)
|
||||
ROPEI_OP(float, rope_i_f32)
|
||||
ROPEI_OP(double, rope_i_f64)
|
||||
|
||||
FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
|
||||
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
|
||||
|
Reference in New Issue
Block a user