Support for "unbatched" rope. (#2926)

* Support for (un)-batched rope.

* Use 3d rope in the rope/ropei/rope_thd functions.

* Get the CPU versions to work.

* Fix the cuda version.

* Adapt the metal side.

* Fix the metal tests.
Author: Laurent Mazare (committed via GitHub)
Date:   2025-04-27 15:12:02 +02:00
Parent: 6e0646c208
Commit: e3db30021f
6 changed files with 217 additions and 33 deletions (only the CUDA kernel file is excerpted below)
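The new `stride_b` argument threaded through all three rope kernels is what toggles the two modes: `stride_b == 0` keeps the old behavior where a single `(t, d/2)` cos/sin table is shared by the whole batch, while a non-zero `stride_b` (the number of `src` elements in one batch) selects each batch's own slice of a 3D `(b, t, d/2)` table. A minimal CPU sketch of the interleaved variant, assuming a flattened `(b, h, t, d)` layout for `src` (`ropei_ref` is illustrative, not part of candle):

    // CPU reference for the interleaved rope variant, mirroring the ropei
    // kernel below. Hypothetical helper, not part of the commit.
    #include <cstdint>
    #include <vector>

    void ropei_ref(const std::vector<float>& src, const std::vector<float>& cos_t,
                   const std::vector<float>& sin_t, std::vector<float>& dst,
                   uint32_t b, uint32_t h, uint32_t t, uint32_t d, bool batched) {
        const uint32_t td = t * d;
        // 0 selects the shared table; otherwise stride_b is the element count
        // of one batch of src, which is what the kernel expects.
        const uint32_t stride_b = batched ? h * t * d : 0;
        for (uint32_t idx = 0; 2 * idx < b * h * td; ++idx) {
            uint32_t rope_idx = idx % (td / 2);
            if (stride_b > 0) {
                uint32_t b_idx = (2 * idx) / stride_b;
                rope_idx += b_idx * (td / 2);
            }
            const float c = cos_t[rope_idx];
            const float s = sin_t[rope_idx];
            dst[2 * idx] = src[2 * idx] * c - src[2 * idx + 1] * s;
            dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c;
        }
    }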


@@ -219,11 +219,15 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
 }

 template <typename T>
-__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
+__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t stride_b) {
     const int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (2 * idx >= bh * td) return;

     uint32_t rope_idx = idx % (td / 2);
+    if (stride_b > 0) {
+        uint32_t b_idx = (2 * idx) / stride_b;
+        rope_idx += b_idx * (td / 2);
+    }
     T c = cos[rope_idx];
     T s = sin[rope_idx];
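A quick way to convince yourself the new indexing stays in bounds: thread `idx` owns the element pair `(2 * idx, 2 * idx + 1)`, so `(2 * idx) / stride_b` recovers the batch index, and each batch then gets its own `td / 2`-entry window into the cos/sin tables. A standalone host-side check with hypothetical sizes (assuming `stride_b == h * t * d`):

    // Exhaustive bounds check of the batched ropei index math.
    #include <cassert>
    #include <cstdint>

    int main() {
        const uint32_t b = 2, h = 3, t = 4, d = 8;
        const uint32_t td = t * d;
        const uint32_t stride_b = h * t * d;       // src elements per batch
        for (uint32_t idx = 0; 2 * idx < b * h * td; ++idx) {
            uint32_t rope_idx = idx % (td / 2);    // position in one (t, d/2) table
            uint32_t b_idx = (2 * idx) / stride_b; // batch owning this pair
            uint32_t batched_idx = rope_idx + b_idx * (td / 2);
            assert(b_idx < b);
            assert(batched_idx < b * (td / 2));    // inside the (b, t, d/2) table
        }
        return 0;
    }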
@@ -232,7 +236,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
 }

 template <typename T>
-__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
+__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d, const uint32_t stride_b) {
     const int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (2 * idx >= bh * td) return;
@@ -243,6 +247,10 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const
     uint32_t i1 = i_bh * td + i_t * d + i_d;
     uint32_t i2 = i1 + d / 2;
     uint32_t i_cs = i_t * (d / 2) + i_d;
+    if (stride_b > 0) {
+        uint32_t b_idx = (2 * idx) / stride_b;
+        i_cs += b_idx * (td / 2);
+    }
     T c = cos[i_cs];
     T s = sin[i_cs];
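The non-interleaved variant rotates element `i1` against `i2 = i1 + d / 2` rather than against its neighbor, but the batching logic is identical: `i_cs` indexes a `(t, d/2)` table and is offset by whole tables per batch. A hypothetical CPU mirror of the kernel's math, again assuming `stride_b == h * t * d` for a `(b, h, t, d)` layout:

    // CPU mirror of the non-interleaved rope kernel (illustrative only).
    #include <cstdint>

    void rope_ref(const float* src, const float* cos_t, const float* sin_t,
                  float* dst, uint32_t bh, uint32_t td, uint32_t d,
                  uint32_t stride_b) {
        for (uint32_t idx = 0; 2 * idx < bh * td; ++idx) {
            uint32_t i_bh = idx / (td / 2);
            uint32_t i_td = idx - (td / 2) * i_bh;
            uint32_t i_t = i_td / (d / 2);
            uint32_t i_d = i_td - (d / 2) * i_t;
            uint32_t i1 = i_bh * td + i_t * d + i_d; // first half of the pair
            uint32_t i2 = i1 + d / 2;                // second half, d/2 apart
            uint32_t i_cs = i_t * (d / 2) + i_d;     // (t, d/2) table lookup
            if (stride_b > 0)                        // jump to this batch's table
                i_cs += ((2 * idx) / stride_b) * (td / 2);
            const float c = cos_t[i_cs];
            const float s = sin_t[i_cs];
            dst[i1] = src[i1] * c - src[i2] * s;
            dst[i2] = src[i1] * s + src[i2] * c;
        }
    }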
@@ -259,7 +267,8 @@ __device__ void rope_thd(
     const uint32_t b,
     const uint32_t t,
     const uint32_t h,
-    const uint32_t d
+    const uint32_t d,
+    const uint32_t stride_b
 ) {
     const int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (2 * idx >= b * t * h * d) return;
@@ -270,6 +279,10 @@ __device__ void rope_thd(
     uint32_t i1 = i_bth * d + i_d;
     uint32_t i2 = i1 + d / 2;
     uint32_t i_cs = i_t * (d / 2) + i_d;
+    if (stride_b > 0) {
+        uint32_t b_idx = (2 * idx) / stride_b;
+        i_cs += b_idx * ((t * d) / 2);
+    }
     T c = cos[i_cs];
     T s = sin[i_cs];
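`rope_thd` works on a `(b, t, h, d)` layout, so the per-batch cos/sin slice holds `(t * d) / 2` entries, the same quantity as `td / 2` in the other kernels but written in terms of this kernel's separate `t` and `d` arguments. Restating the index decomposition on the host (a sketch; for this layout `stride_b` would be `t * h * d`):

    // Host-side restatement of the rope_thd indexing (illustrative helper).
    #include <cstdint>

    struct ThdIdx { uint32_t i1, i2, i_cs; };

    ThdIdx rope_thd_idx(uint32_t idx, uint32_t t, uint32_t h, uint32_t d,
                        uint32_t stride_b) {
        uint32_t i_bth = idx / (d / 2);
        uint32_t i_d = idx - (d / 2) * i_bth;  // column within the half-dim
        uint32_t i_t = (i_bth / h) % t;        // time step, independent of h
        uint32_t i1 = i_bth * d + i_d;
        uint32_t i2 = i1 + d / 2;
        uint32_t i_cs = i_t * (d / 2) + i_d;   // (t, d/2) table lookup
        if (stride_b > 0)                      // batched: offset by whole tables
            i_cs += ((2 * idx) / stride_b) * ((t * d) / 2);
        return {i1, i2, i_cs};
    }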
@@ -546,8 +559,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
     const TYPENAME *sin, \
     TYPENAME *dst, \
     const uint32_t bh, \
-    const uint32_t td) { \
-  ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
+    const uint32_t td, \
+    const uint32_t stride_b) { \
+  ropei<TYPENAME>(src, cos, sin, dst, bh, td, stride_b); \
 } \
 extern "C" __global__ void FN_NAME( \
     const TYPENAME *src, \
@@ -556,8 +570,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
     TYPENAME *dst, \
     const uint32_t bh, \
     const uint32_t td, \
-    const uint32_t d) { \
-  rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
+    const uint32_t d, \
+    const uint32_t stride_b) { \
+  rope<TYPENAME>(src, cos, sin, dst, bh, td, d, stride_b); \
 } \
 extern "C" __global__ void FN_NAME_THD( \
     const TYPENAME *src, \
@@ -567,8 +582,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
     const uint32_t b, \
     const uint32_t t, \
     const uint32_t h, \
-    const uint32_t d) { \
-  rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
+    const uint32_t d, \
+    const uint32_t stride_b) { \
+  rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d, stride_b); \
 } \

 #if __CUDA_ARCH__ >= 800
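For existing callers nothing changes beyond the extra argument: the `ROPE_OP` macro just forwards `stride_b`, and passing `0` preserves the unbatched behavior. A hedged launch sketch, assuming the macro instantiates an f32 interleaved kernel named `rope_i_f32` (the kernel name and launch configuration here are assumptions; the real launches live on the Rust side of candle):

    // Illustrative host launches; names and grid sizing are not from the commit.
    void launch_ropei(const float* src, const float* cos_t, const float* sin_t,
                      float* dst, uint32_t b, uint32_t h, uint32_t t, uint32_t d) {
        const uint32_t n_pairs = (b * h * t * d) / 2; // one thread per pair
        const uint32_t threads = 256;
        const uint32_t blocks = (n_pairs + threads - 1) / threads;
        // Shared (t, d/2) cos/sin table: the pre-existing, unbatched path.
        rope_i_f32<<<blocks, threads>>>(src, cos_t, sin_t, dst, b * h, t * d, 0);
        // Per-batch (b, t, d/2) tables: stride_b is the element count per batch.
        rope_i_f32<<<blocks, threads>>>(src, cos_t, sin_t, dst, b * h, t * d, h * t * d);
    }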