add u32 - U32 gather (#2653)

This commit is contained in:
zachcp
2024-11-30 17:18:07 -05:00
committed by GitHub
parent b52c2c6050
commit dba7a9c93e
2 changed files with 81 additions and 79 deletions

View File

@ -1244,6 +1244,7 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "gather_u32_f32", (DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16", (DType::U32, DType::F16) => "gather_u32_f16",
(DType::U32, DType::BF16) => "gather_u32_bf16", (DType::U32, DType::BF16) => "gather_u32_bf16",
(DType::U32, DType::U32) => "gather_u32_u32",
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"), (left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
}; };
let command_buffer = self.device.command_buffer()?; let command_buffer = self.device.command_buffer()?;

View File

@ -17,33 +17,33 @@ METAL_FUNC uint get_strided_index(
} }
template<typename TYPENAME, typename INDEX_TYPENAME> template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index( METAL_FUNC void index(
constant size_t &dst_size, constant size_t &dst_size,
constant size_t &left_size, constant size_t &left_size,
constant size_t &src_dim_size, constant size_t &src_dim_size,
constant size_t &right_size, constant size_t &right_size,
constant size_t &ids_size, constant size_t &ids_size,
constant bool &contiguous, constant bool &contiguous,
constant size_t *src_dims, constant size_t *src_dims,
constant size_t *src_strides, constant size_t *src_strides,
const device TYPENAME *input, const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids, const device INDEX_TYPENAME *input_ids,
device TYPENAME *output, device TYPENAME *output,
uint tid [[ thread_position_in_grid ]] uint tid [[ thread_position_in_grid ]]
) { ) {
if (tid >= dst_size) { if (tid >= dst_size) {
return; return;
} }
const size_t id_i = (tid / right_size) % ids_size; const size_t id_i = (tid / right_size) % ids_size;
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
const size_t right_rank_i = tid % right_size; const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size; const size_t left_rank_i = tid / right_size / ids_size;
/* /*
// Force prevent out of bounds indexing // Force prevent out of bounds indexing
// since there doesn't seem to be a good way to force crash // since there doesn't seem to be a good way to force crash
// No need to check for zero we're only allowing unsized. // No need to check for zero we're only allowing unsized.
*/ */
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides); const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
output[tid] = input[strided_src_i]; output[tid] = input[strided_src_i];
} }
@ -68,25 +68,25 @@ kernel void NAME( \
template<typename TYPENAME, typename INDEX_TYPENAME> template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void gather( METAL_FUNC void gather(
constant size_t &dst_size, constant size_t &dst_size,
constant size_t &left_size, constant size_t &left_size,
constant size_t &src_dim_size, constant size_t &src_dim_size,
constant size_t &right_size, constant size_t &right_size,
constant size_t &ids_size, constant size_t &ids_size,
const device TYPENAME *input, const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids, const device INDEX_TYPENAME *input_ids,
device TYPENAME *output, device TYPENAME *output,
uint tid [[ thread_position_in_grid ]] uint tid [[ thread_position_in_grid ]]
) { ) {
if (tid >= dst_size) { if (tid >= dst_size) {
return; return;
} }
const INDEX_TYPENAME input_i = input_ids[tid]; const INDEX_TYPENAME input_i = input_ids[tid];
const size_t right_rank_i = tid % right_size; const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size; const size_t left_rank_i = tid / right_size / ids_size;
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
output[tid] = input[src_i]; output[tid] = input[src_i];
} }
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \ # define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@ -105,27 +105,27 @@ kernel void NAME( \
} }
template<typename TYPENAME, typename INDEX_TYPENAME> template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter_add( METAL_FUNC void scatter_add(
constant size_t &dst_size, constant size_t &dst_size,
constant size_t &left_size, constant size_t &left_size,
constant size_t &src_dim_size, constant size_t &src_dim_size,
constant size_t &right_size, constant size_t &right_size,
constant size_t &dst_dim_size, constant size_t &dst_dim_size,
const device TYPENAME *input, const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids, const device INDEX_TYPENAME *input_ids,
device TYPENAME *output, device TYPENAME *output,
uint tid [[ thread_position_in_grid ]] uint tid [[ thread_position_in_grid ]]
) { ) {
if (tid >= dst_size) { if (tid >= dst_size) {
return; return;
} }
const size_t right_rank_i = tid % right_size; const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size; const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) { for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const INDEX_TYPENAME idx = input_ids[src_i]; const INDEX_TYPENAME idx = input_ids[src_i];
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i]; output[dst_i] += input[src_i];
} }
} }
@ -145,28 +145,28 @@ kernel void NAME( \
} }
template<typename TYPENAME, typename INDEX_TYPENAME> template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index_add( METAL_FUNC void index_add(
constant size_t &dst_size, constant size_t &dst_size,
constant size_t &left_size, constant size_t &left_size,
constant size_t &src_dim_size, constant size_t &src_dim_size,
constant size_t &right_size, constant size_t &right_size,
constant size_t &dst_dim_size, constant size_t &dst_dim_size,
constant size_t &ids_dim_size, constant size_t &ids_dim_size,
const device TYPENAME *input, const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids, const device INDEX_TYPENAME *input_ids,
device TYPENAME *output, device TYPENAME *output,
uint tid [[ thread_position_in_grid ]] uint tid [[ thread_position_in_grid ]]
) { ) {
if (tid >= dst_size) { if (tid >= dst_size) {
return; return;
} }
const size_t right_rank_i = tid % right_size; const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size; const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < ids_dim_size; ++j) { for (unsigned int j = 0; j < ids_dim_size; ++j) {
const INDEX_TYPENAME idx = input_ids[j]; const INDEX_TYPENAME idx = input_ids[j];
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i]; output[dst_i] += input[src_i];
} }
} }
@ -214,6 +214,7 @@ GATHER_OP(gather_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__) #if defined(__HAVE_BFLOAT__)
GATHER_OP(gather_u32_bf16, uint, bfloat) GATHER_OP(gather_u32_bf16, uint, bfloat)
#endif #endif
GATHER_OP(gather_u32_u32, uint, uint)
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float) SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)