fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0 (#1037)

* fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0

* cargo fmt
This commit is contained in:
Gonzalo
2023-10-05 14:46:13 -03:00
committed by GitHub
parent f0c619a4af
commit 8f7973958c
3 changed files with 21 additions and 8 deletions

View File

@ -891,7 +891,8 @@ impl<'a> Map1 for IndexSelect<'a> {
};
let left_size: usize = src_l.dims()[..self.2].iter().product();
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
let dim_size = ids_shape.elem_count();
let src_dim_size = src_l.dims()[self.2];
let ids_dim_size = ids_shape.elem_count();
let dst_el = ids_shape.elem_count() * left_size * right_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
@ -905,7 +906,8 @@ impl<'a> Map1 for IndexSelect<'a> {
&src,
&out,
left_size,
dim_size,
src_dim_size,
ids_dim_size,
right_size,
);
// SAFETY: ffi.

View File

@ -680,6 +680,15 @@ fn index_select(device: &Device) -> Result<()> {
[3.0, 4.0, 5.0],
]
);
// Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
Ok(())
}