From 8f7973958c55324a24f0c514e7ac6ded6681980f Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 5 Oct 2023 14:46:13 -0300 Subject: [PATCH] fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0 (#1037) * fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0 * cargo fmt --- candle-core/src/cuda_backend.rs | 6 ++++-- candle-core/tests/tensor_tests.rs | 9 +++++++++ candle-kernels/src/indexing.cu | 14 ++++++++------ 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 1599425f..f7518067 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -891,7 +891,8 @@ impl<'a> Map1 for IndexSelect<'a> { }; let left_size: usize = src_l.dims()[..self.2].iter().product(); let right_size: usize = src_l.dims()[self.2 + 1..].iter().product(); - let dim_size = ids_shape.elem_count(); + let src_dim_size = src_l.dims()[self.2]; + let ids_dim_size = ids_shape.elem_count(); let dst_el = ids_shape.elem_count() * left_size * right_size; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let func = dev.get_or_load_func(&kernel_name::(name), kernels::INDEXING)?; @@ -905,7 +906,8 @@ impl<'a> Map1 for IndexSelect<'a> { &src, &out, left_size, - dim_size, + src_dim_size, + ids_dim_size, right_size, ); // SAFETY: ffi. diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 2e867b26..a50f3a6c 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -680,6 +680,15 @@ fn index_select(device: &Device) -> Result<()> { [3.0, 4.0, 5.0], ] ); + + // Test when selecting dim > 0 with ids size different from elem count of + // target dim in source/input. + let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?; + let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?; + assert_eq!(t.to_vec2::()?, &[[1.0, 2.0], [3.0, 4.0]]); + let hs = t.index_select(&ids, 1)?; + assert_eq!(hs.to_vec2::()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]); + Ok(()) } diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index 0272a330..8fc69363 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -12,17 +12,18 @@ __device__ void index_select( const T *inp, T *out, const size_t left_size, - const size_t dim_size, + const size_t src_dim_size, + const size_t ids_dim_size, const size_t right_size ) { const size_t *dims = info; const size_t *strides = info + num_dims; bool b = is_contiguous(num_dims, dims, strides); for (unsigned int dst_i = blockIdx.x * blockDim.x + threadIdx.x; dst_i < numel; dst_i += blockDim.x * gridDim.x) { - unsigned int left_i = dst_i / (dim_size * right_size); - unsigned int id_i = dst_i / right_size % dim_size; + unsigned int left_i = dst_i / (ids_dim_size * right_size); + unsigned int id_i = dst_i / right_size % ids_dim_size; unsigned int right_i = dst_i % right_size; - unsigned int src_i = left_i * (dim_size * right_size) + ids[id_i] * right_size + right_i; + unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i; unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides); out[dst_i] = inp[strided_i]; } @@ -37,9 +38,10 @@ extern "C" __global__ void FN_NAME( \ const TYPENAME *inp, \ TYPENAME *out, \ const size_t left_size, \ - const size_t dim_size, \ + const size_t src_dim_size, \ + const size_t ids_dim_size, \ const size_t right_size \ -) { index_select(numel, num_dims, info, ids, inp, out, left_size, dim_size, right_size); } \ +) { index_select(numel, num_dims, info, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \ template __device__ void gather(