mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add the scatter op. (#2921)
* Add the scatter op. * Backprop support. * Cuda support.
This commit is contained in:
@ -104,6 +104,31 @@ kernel void NAME( \
|
||||
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
|
||||
}
|
||||
|
||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||
METAL_FUNC void scatter(
|
||||
constant size_t &dst_size,
|
||||
constant size_t &left_size,
|
||||
constant size_t &src_dim_size,
|
||||
constant size_t &right_size,
|
||||
constant size_t &dst_dim_size,
|
||||
const device TYPENAME *input,
|
||||
const device INDEX_TYPENAME *input_ids,
|
||||
device TYPENAME *output,
|
||||
uint tid [[ thread_position_in_grid ]]
|
||||
) {
|
||||
if (tid >= dst_size) {
|
||||
return;
|
||||
}
|
||||
const size_t right_rank_i = tid % right_size;
|
||||
const size_t left_rank_i = tid / right_size;
|
||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||
const INDEX_TYPENAME idx = input_ids[src_i];
|
||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||
output[dst_i] = input[src_i];
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||
METAL_FUNC void scatter_add(
|
||||
constant size_t &dst_size,
|
||||
@ -129,6 +154,21 @@ METAL_FUNC void scatter_add(
|
||||
}
|
||||
}
|
||||
|
||||
# define SCATTER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||
kernel void NAME( \
|
||||
constant size_t &dst_size, \
|
||||
constant size_t &left_size, \
|
||||
constant size_t &src_dim_size, \
|
||||
constant size_t &right_size, \
|
||||
constant size_t &dst_dim_size, \
|
||||
const device TYPENAME *input, \
|
||||
const device INDEX_TYPENAME *input_ids, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
scatter<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
|
||||
}
|
||||
|
||||
# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||
kernel void NAME( \
|
||||
constant size_t &dst_size, \
|
||||
@ -235,6 +275,19 @@ SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
|
||||
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
|
||||
#endif
|
||||
|
||||
SCATTER_OP(s_u32_f32, uint32_t, float)
|
||||
SCATTER_OP(s_u8_f32, uint8_t, float)
|
||||
SCATTER_OP(s_i64_f32, int64_t, float)
|
||||
SCATTER_OP(s_u32_u32, uint32_t, uint32_t)
|
||||
SCATTER_OP(s_u32_f16, uint32_t, half)
|
||||
SCATTER_OP(s_u8_f16, uint8_t, half)
|
||||
SCATTER_OP(s_i64_f16, int64_t, half)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
SCATTER_OP(s_u32_bf16, uint32_t, bfloat)
|
||||
SCATTER_OP(s_u8_bf16, uint8_t, bfloat)
|
||||
SCATTER_OP(s_i64_bf16, int64_t, bfloat)
|
||||
#endif
|
||||
|
||||
// i64
|
||||
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
|
||||
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
|
||||
|
@ -1447,7 +1447,7 @@ pub fn call_gather(
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_scatter_add(
|
||||
pub fn call_scatter(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
|
@ -1574,7 +1574,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
let input_buffer = new_buffer(&device, input);
|
||||
let ids_buffer = new_buffer(&device, ids);
|
||||
let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
|
||||
call_scatter_add(
|
||||
call_scatter(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
|
Reference in New Issue
Block a user