Add the rope THD kernel. (#2014)

* Add the rope THD kernel.

* Cuda kernel for rope-thd.

* Add the metal kernels.

* Add a dedicated test.
This commit is contained in:
Laurent Mazare
2024-04-05 08:32:58 +02:00
committed by GitHub
parent ace282e5c2
commit 2ac302a5d1
6 changed files with 400 additions and 31 deletions

View File

@ -179,6 +179,33 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const
dst[i2] = src[i1] * s + src[i2] * c;
}
template <typename T>
__device__ void rope_thd(
const T * src,
const T * cos,
const T * sin,
T * dst,
const uint32_t b,
const uint32_t t,
const uint32_t h,
const uint32_t d
) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= b * t * h * d) return;
uint32_t i_bth = idx / (d / 2);
uint32_t i_d = idx - (d / 2) * i_bth;
uint32_t i_t = (i_bth / h) % t;
uint32_t i1 = i_bth * d + i_d;
uint32_t i2 = i1 + d / 2;
uint32_t i_cs = i_t * (d / 2) + i_d;
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
dst[i2] = src[i1] * s + src[i2] * c;
}
template <typename T>
__device__ void
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
@ -434,7 +461,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I) \
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
extern "C" __global__ void FN_NAME_I( \
const TYPENAME *src, \
const TYPENAME *cos, \
@ -454,11 +481,22 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const uint32_t d) { \
rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
} \
extern "C" __global__ void FN_NAME_THD( \
const TYPENAME *src, \
const TYPENAME *cos, \
const TYPENAME *sin, \
TYPENAME *dst, \
const uint32_t b, \
const uint32_t t, \
const uint32_t h, \
const uint32_t d) { \
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
} \
#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif
@ -466,7 +504,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif
@ -478,8 +516,8 @@ SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
ROPE_OP(float, rope_f32, rope_i_f32)
ROPE_OP(double, rope_f64, rope_i_f64)
ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)
FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)