mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0 (#1037)
* fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0 * cargo fmt
This commit is contained in:
@ -12,17 +12,18 @@ __device__ void index_select(
|
||||
const T *inp,
|
||||
T *out,
|
||||
const size_t left_size,
|
||||
const size_t dim_size,
|
||||
const size_t src_dim_size,
|
||||
const size_t ids_dim_size,
|
||||
const size_t right_size
|
||||
) {
|
||||
const size_t *dims = info;
|
||||
const size_t *strides = info + num_dims;
|
||||
bool b = is_contiguous(num_dims, dims, strides);
|
||||
for (unsigned int dst_i = blockIdx.x * blockDim.x + threadIdx.x; dst_i < numel; dst_i += blockDim.x * gridDim.x) {
|
||||
unsigned int left_i = dst_i / (dim_size * right_size);
|
||||
unsigned int id_i = dst_i / right_size % dim_size;
|
||||
unsigned int left_i = dst_i / (ids_dim_size * right_size);
|
||||
unsigned int id_i = dst_i / right_size % ids_dim_size;
|
||||
unsigned int right_i = dst_i % right_size;
|
||||
unsigned int src_i = left_i * (dim_size * right_size) + ids[id_i] * right_size + right_i;
|
||||
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
|
||||
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
|
||||
out[dst_i] = inp[strided_i];
|
||||
}
|
||||
@ -37,9 +38,10 @@ extern "C" __global__ void FN_NAME( \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out, \
|
||||
const size_t left_size, \
|
||||
const size_t dim_size, \
|
||||
const size_t src_dim_size, \
|
||||
const size_t ids_dim_size, \
|
||||
const size_t right_size \
|
||||
) { index_select(numel, num_dims, info, ids, inp, out, left_size, dim_size, right_size); } \
|
||||
) { index_select(numel, num_dims, info, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \
|
||||
|
||||
template<typename T, typename I>
|
||||
__device__ void gather(
|
||||
|
Reference in New Issue
Block a user