Add the scatter op. (#2921)

* Add the scatter op.

* Backprop support.

* Cuda support.
This commit is contained in:
Laurent Mazare
2025-04-25 21:46:58 +02:00
committed by GitHub
parent 3aeb9575c7
commit 3827685524
15 changed files with 429 additions and 19 deletions

View File

@ -104,6 +104,31 @@ kernel void NAME( \
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const INDEX_TYPENAME idx = input_ids[src_i];
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] = input[src_i];
}
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter_add(
constant size_t &dst_size,
@ -129,6 +154,21 @@ METAL_FUNC void scatter_add(
}
}
# define SCATTER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &dst_dim_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
scatter<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
}
# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
@ -235,6 +275,19 @@ SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
#endif
SCATTER_OP(s_u32_f32, uint32_t, float)
SCATTER_OP(s_u8_f32, uint8_t, float)
SCATTER_OP(s_i64_f32, int64_t, float)
SCATTER_OP(s_u32_u32, uint32_t, uint32_t)
SCATTER_OP(s_u32_f16, uint32_t, half)
SCATTER_OP(s_u8_f16, uint8_t, half)
SCATTER_OP(s_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
SCATTER_OP(s_u32_bf16, uint32_t, bfloat)
SCATTER_OP(s_u8_bf16, uint8_t, bfloat)
SCATTER_OP(s_i64_bf16, int64_t, bfloat)
#endif
// i64
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
INDEX_ADD_OP(ia_i64_f32, int64_t, float)

View File

@ -1447,7 +1447,7 @@ pub fn call_gather(
}
#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
pub fn call_scatter(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,

View File

@ -1574,7 +1574,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
let input_buffer = new_buffer(&device, input);
let ids_buffer = new_buffer(&device, ids);
let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
call_scatter_add(
call_scatter(
&device,
command_buffer,
&kernels,