mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Optimize the index-select cuda kernel. (#976)
This commit is contained in:
@ -17,20 +17,14 @@ __device__ void index_select(
|
||||
) {
|
||||
const size_t *dims = info;
|
||||
const size_t *strides = info + num_dims;
|
||||
if (is_contiguous(num_dims, dims, strides)) {
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||
for (unsigned int j = 0; j < left_size; ++j) {
|
||||
memcpy(&out[(i + j * numel) * right_size], &inp[(j * dim_size + ids[i]) * right_size], right_size * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||
for (unsigned int j = 0; j < left_size; ++j) {
|
||||
memcpy(&out[(i + j * numel) * right_size], &inp[(j * dim_size + ids[strided_i]) * right_size], right_size * sizeof(T));
|
||||
}
|
||||
}
|
||||
bool b = is_contiguous(num_dims, dims, strides);
|
||||
for (unsigned int dst_i = blockIdx.x * blockDim.x + threadIdx.x; dst_i < numel; dst_i += blockDim.x * gridDim.x) {
|
||||
unsigned int left_i = dst_i / (dim_size * right_size);
|
||||
unsigned int id_i = dst_i / right_size % dim_size;
|
||||
unsigned int right_i = dst_i % right_size;
|
||||
unsigned int src_i = left_i * (dim_size * right_size) + ids[id_i] * right_size + right_i;
|
||||
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
|
||||
out[dst_i] = inp[strided_i];
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user