mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
add u32 - U32 gather (#2653)
This commit is contained in:
@ -1244,6 +1244,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
(DType::U32, DType::F32) => "gather_u32_f32",
|
(DType::U32, DType::F32) => "gather_u32_f32",
|
||||||
(DType::U32, DType::F16) => "gather_u32_f16",
|
(DType::U32, DType::F16) => "gather_u32_f16",
|
||||||
(DType::U32, DType::BF16) => "gather_u32_bf16",
|
(DType::U32, DType::BF16) => "gather_u32_bf16",
|
||||||
|
(DType::U32, DType::U32) => "gather_u32_u32",
|
||||||
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
|
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
|
||||||
};
|
};
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
|
@ -17,33 +17,33 @@ METAL_FUNC uint get_strided_index(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||||
METAL_FUNC void index(
|
METAL_FUNC void index(
|
||||||
constant size_t &dst_size,
|
constant size_t &dst_size,
|
||||||
constant size_t &left_size,
|
constant size_t &left_size,
|
||||||
constant size_t &src_dim_size,
|
constant size_t &src_dim_size,
|
||||||
constant size_t &right_size,
|
constant size_t &right_size,
|
||||||
constant size_t &ids_size,
|
constant size_t &ids_size,
|
||||||
constant bool &contiguous,
|
constant bool &contiguous,
|
||||||
constant size_t *src_dims,
|
constant size_t *src_dims,
|
||||||
constant size_t *src_strides,
|
constant size_t *src_strides,
|
||||||
const device TYPENAME *input,
|
const device TYPENAME *input,
|
||||||
const device INDEX_TYPENAME *input_ids,
|
const device INDEX_TYPENAME *input_ids,
|
||||||
device TYPENAME *output,
|
device TYPENAME *output,
|
||||||
uint tid [[ thread_position_in_grid ]]
|
uint tid [[ thread_position_in_grid ]]
|
||||||
) {
|
) {
|
||||||
if (tid >= dst_size) {
|
if (tid >= dst_size) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const size_t id_i = (tid / right_size) % ids_size;
|
const size_t id_i = (tid / right_size) % ids_size;
|
||||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
|
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
|
||||||
const size_t right_rank_i = tid % right_size;
|
const size_t right_rank_i = tid % right_size;
|
||||||
const size_t left_rank_i = tid / right_size / ids_size;
|
const size_t left_rank_i = tid / right_size / ids_size;
|
||||||
/*
|
/*
|
||||||
// Force prevent out of bounds indexing
|
// Force prevent out of bounds indexing
|
||||||
// since there doesn't seem to be a good way to force crash
|
// since there doesn't seem to be a good way to force crash
|
||||||
// No need to check for zero we're only allowing unsized.
|
// No need to check for zero we're only allowing unsized.
|
||||||
*/
|
*/
|
||||||
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
|
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
|
||||||
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
|
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
|
||||||
output[tid] = input[strided_src_i];
|
output[tid] = input[strided_src_i];
|
||||||
}
|
}
|
||||||
@ -68,25 +68,25 @@ kernel void NAME( \
|
|||||||
|
|
||||||
|
|
||||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||||
METAL_FUNC void gather(
|
METAL_FUNC void gather(
|
||||||
constant size_t &dst_size,
|
constant size_t &dst_size,
|
||||||
constant size_t &left_size,
|
constant size_t &left_size,
|
||||||
constant size_t &src_dim_size,
|
constant size_t &src_dim_size,
|
||||||
constant size_t &right_size,
|
constant size_t &right_size,
|
||||||
constant size_t &ids_size,
|
constant size_t &ids_size,
|
||||||
const device TYPENAME *input,
|
const device TYPENAME *input,
|
||||||
const device INDEX_TYPENAME *input_ids,
|
const device INDEX_TYPENAME *input_ids,
|
||||||
device TYPENAME *output,
|
device TYPENAME *output,
|
||||||
uint tid [[ thread_position_in_grid ]]
|
uint tid [[ thread_position_in_grid ]]
|
||||||
) {
|
) {
|
||||||
if (tid >= dst_size) {
|
if (tid >= dst_size) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const INDEX_TYPENAME input_i = input_ids[tid];
|
const INDEX_TYPENAME input_i = input_ids[tid];
|
||||||
const size_t right_rank_i = tid % right_size;
|
const size_t right_rank_i = tid % right_size;
|
||||||
const size_t left_rank_i = tid / right_size / ids_size;
|
const size_t left_rank_i = tid / right_size / ids_size;
|
||||||
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
|
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
|
||||||
output[tid] = input[src_i];
|
output[tid] = input[src_i];
|
||||||
}
|
}
|
||||||
|
|
||||||
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
@ -105,27 +105,27 @@ kernel void NAME( \
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||||
METAL_FUNC void scatter_add(
|
METAL_FUNC void scatter_add(
|
||||||
constant size_t &dst_size,
|
constant size_t &dst_size,
|
||||||
constant size_t &left_size,
|
constant size_t &left_size,
|
||||||
constant size_t &src_dim_size,
|
constant size_t &src_dim_size,
|
||||||
constant size_t &right_size,
|
constant size_t &right_size,
|
||||||
constant size_t &dst_dim_size,
|
constant size_t &dst_dim_size,
|
||||||
const device TYPENAME *input,
|
const device TYPENAME *input,
|
||||||
const device INDEX_TYPENAME *input_ids,
|
const device INDEX_TYPENAME *input_ids,
|
||||||
device TYPENAME *output,
|
device TYPENAME *output,
|
||||||
uint tid [[ thread_position_in_grid ]]
|
uint tid [[ thread_position_in_grid ]]
|
||||||
) {
|
) {
|
||||||
if (tid >= dst_size) {
|
if (tid >= dst_size) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const size_t right_rank_i = tid % right_size;
|
const size_t right_rank_i = tid % right_size;
|
||||||
const size_t left_rank_i = tid / right_size;
|
const size_t left_rank_i = tid / right_size;
|
||||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||||
const INDEX_TYPENAME idx = input_ids[src_i];
|
const INDEX_TYPENAME idx = input_ids[src_i];
|
||||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||||
output[dst_i] += input[src_i];
|
output[dst_i] += input[src_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,28 +145,28 @@ kernel void NAME( \
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||||
METAL_FUNC void index_add(
|
METAL_FUNC void index_add(
|
||||||
constant size_t &dst_size,
|
constant size_t &dst_size,
|
||||||
constant size_t &left_size,
|
constant size_t &left_size,
|
||||||
constant size_t &src_dim_size,
|
constant size_t &src_dim_size,
|
||||||
constant size_t &right_size,
|
constant size_t &right_size,
|
||||||
constant size_t &dst_dim_size,
|
constant size_t &dst_dim_size,
|
||||||
constant size_t &ids_dim_size,
|
constant size_t &ids_dim_size,
|
||||||
const device TYPENAME *input,
|
const device TYPENAME *input,
|
||||||
const device INDEX_TYPENAME *input_ids,
|
const device INDEX_TYPENAME *input_ids,
|
||||||
device TYPENAME *output,
|
device TYPENAME *output,
|
||||||
uint tid [[ thread_position_in_grid ]]
|
uint tid [[ thread_position_in_grid ]]
|
||||||
) {
|
) {
|
||||||
if (tid >= dst_size) {
|
if (tid >= dst_size) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const size_t right_rank_i = tid % right_size;
|
const size_t right_rank_i = tid % right_size;
|
||||||
const size_t left_rank_i = tid / right_size;
|
const size_t left_rank_i = tid / right_size;
|
||||||
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
||||||
const INDEX_TYPENAME idx = input_ids[j];
|
const INDEX_TYPENAME idx = input_ids[j];
|
||||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||||
output[dst_i] += input[src_i];
|
output[dst_i] += input[src_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,6 +214,7 @@ GATHER_OP(gather_u32_f16, uint, half)
|
|||||||
#if defined(__HAVE_BFLOAT__)
|
#if defined(__HAVE_BFLOAT__)
|
||||||
GATHER_OP(gather_u32_bf16, uint, bfloat)
|
GATHER_OP(gather_u32_bf16, uint, bfloat)
|
||||||
#endif
|
#endif
|
||||||
|
GATHER_OP(gather_u32_u32, uint, uint)
|
||||||
|
|
||||||
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
|
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
|
||||||
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
|
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
|
||||||
|
Reference in New Issue
Block a user