mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Support for "unbatched" rope. (#2926)
* Support for (un)-batched rope. * Use 3d rope in the rope/ropei/rope_thd functions. * Get the CPU versions to work. * Fix the cuda version. * Adapt the metal side. * Fix the metal tests.
This commit is contained in:
@ -219,11 +219,15 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
|
||||
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t stride_b) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (2 * idx >= bh * td) return;
|
||||
|
||||
uint32_t rope_idx = idx % (td / 2);
|
||||
if (stride_b > 0) {
|
||||
uint32_t b_idx = (2 * idx) / stride_b;
|
||||
rope_idx += b_idx * (td / 2);
|
||||
}
|
||||
T c = cos[rope_idx];
|
||||
T s = sin[rope_idx];
|
||||
|
||||
@ -232,7 +236,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
|
||||
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d, const uint32_t stride_b) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (2 * idx >= bh * td) return;
|
||||
|
||||
@ -243,6 +247,10 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const
|
||||
uint32_t i1 = i_bh * td + i_t * d + i_d;
|
||||
uint32_t i2 = i1 + d / 2;
|
||||
uint32_t i_cs = i_t * (d / 2) + i_d;
|
||||
if (stride_b > 0) {
|
||||
uint32_t b_idx = (2 * idx) / stride_b;
|
||||
i_cs += b_idx * (td / 2);
|
||||
}
|
||||
T c = cos[i_cs];
|
||||
T s = sin[i_cs];
|
||||
|
||||
@ -259,7 +267,8 @@ __device__ void rope_thd(
|
||||
const uint32_t b,
|
||||
const uint32_t t,
|
||||
const uint32_t h,
|
||||
const uint32_t d
|
||||
const uint32_t d,
|
||||
const uint32_t stride_b
|
||||
) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (2 * idx >= b * t * h * d) return;
|
||||
@ -270,6 +279,10 @@ __device__ void rope_thd(
|
||||
uint32_t i1 = i_bth * d + i_d;
|
||||
uint32_t i2 = i1 + d / 2;
|
||||
uint32_t i_cs = i_t * (d / 2) + i_d;
|
||||
if (stride_b > 0) {
|
||||
uint32_t b_idx = (2 * idx) / stride_b;
|
||||
i_cs += b_idx * ((t * d) / 2);
|
||||
}
|
||||
T c = cos[i_cs];
|
||||
T s = sin[i_cs];
|
||||
|
||||
@ -546,8 +559,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
const TYPENAME *sin, \
|
||||
TYPENAME *dst, \
|
||||
const uint32_t bh, \
|
||||
const uint32_t td) { \
|
||||
ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
|
||||
const uint32_t td, \
|
||||
const uint32_t stride_b) { \
|
||||
ropei<TYPENAME>(src, cos, sin, dst, bh, td, stride_b); \
|
||||
} \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const TYPENAME *src, \
|
||||
@ -556,8 +570,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
TYPENAME *dst, \
|
||||
const uint32_t bh, \
|
||||
const uint32_t td, \
|
||||
const uint32_t d) { \
|
||||
rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
|
||||
const uint32_t d, \
|
||||
const uint32_t stride_b) { \
|
||||
rope<TYPENAME>(src, cos, sin, dst, bh, td, d, stride_b); \
|
||||
} \
|
||||
extern "C" __global__ void FN_NAME_THD( \
|
||||
const TYPENAME *src, \
|
||||
@ -567,8 +582,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
const uint32_t b, \
|
||||
const uint32_t t, \
|
||||
const uint32_t h, \
|
||||
const uint32_t d) { \
|
||||
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
|
||||
const uint32_t d, \
|
||||
const uint32_t stride_b) { \
|
||||
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d, stride_b); \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
|
Reference in New Issue
Block a user