Add the scatter op. (#2921)

* Add the scatter op.

* Backprop support.

* Cuda support.
This commit is contained in:
Laurent Mazare
2025-04-25 21:46:58 +02:00
committed by GitHub
parent 3aeb9575c7
commit 3827685524
15 changed files with 429 additions and 19 deletions

View File

@ -114,6 +114,30 @@ extern "C" __global__ void FN_NAME( \
const size_t right_size \
) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
template<typename T, typename I>
__device__ void scatter(
const I *ids,
const T *inp,
T *out,
const size_t left_size,
const size_t src_dim_size,
const size_t dst_dim_size,
const size_t right_size
) {
const size_t numel = left_size * right_size;
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
const size_t pre = i / right_size;
const size_t post = i % right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
const size_t idx = ids[src_i];
assert(idx < dst_dim_size);
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] = inp[src_i];
}
}
}
template<typename T, typename I>
__device__ void scatter_add(
const I *ids,
@ -138,6 +162,17 @@ __device__ void scatter_add(
}
}
#define S_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const INDEX_TYPENAME *ids, \
const TYPENAME *inp, \
TYPENAME *out, \
const size_t left_size, \
const size_t src_dim_size, \
const size_t dst_dim_size, \
const size_t right_size \
) { scatter(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const INDEX_TYPENAME *ids, \
@ -163,6 +198,9 @@ IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
S_OP(__nv_bfloat16, int64_t, s_i64_bf16)
S_OP(__nv_bfloat16, uint32_t, s_u32_bf16)
S_OP(__nv_bfloat16, uint8_t, s_u8_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@ -178,6 +216,9 @@ IA_OP(__half, uint8_t, ia_u8_f16)
SA_OP(__half, int64_t, sa_i64_f16)
SA_OP(__half, uint32_t, sa_u32_f16)
SA_OP(__half, uint8_t, sa_u8_f16)
S_OP(__half, int64_t, s_i64_f16)
S_OP(__half, uint32_t, s_u32_f16)
S_OP(__half, uint8_t, s_u8_f16)
#endif
IS_OP(float, int64_t, is_i64_f32)
@ -251,3 +292,21 @@ SA_OP(double, uint8_t, sa_u8_f64)
SA_OP(uint8_t, uint8_t, sa_u8_u8)
SA_OP(uint32_t, uint8_t, sa_u8_u32)
SA_OP(int64_t, uint8_t, sa_u8_i64)
S_OP(float, int64_t, s_i64_f32)
S_OP(double, int64_t, s_i64_f64)
S_OP(uint8_t, int64_t, s_i64_u8)
S_OP(int64_t, int64_t, s_i64_i64)
S_OP(uint32_t, int64_t, s_i64_u32)
S_OP(float, uint32_t, s_u32_f32)
S_OP(double, uint32_t, s_u32_f64)
S_OP(uint8_t, uint32_t, s_u32_u8)
S_OP(int64_t, uint32_t, s_u32_i64)
S_OP(uint32_t, uint32_t, s_u32_u32)
S_OP(float, uint8_t, s_u8_f32)
S_OP(double, uint8_t, s_u8_f64)
S_OP(uint8_t, uint8_t, s_u8_u8)
S_OP(uint32_t, uint8_t, s_u8_u32)
S_OP(int64_t, uint8_t, s_u8_i64)