Fix for the index-select CUDA setup. (#1022)

* Fix for index-select.

* Better fix, plus some added tests.
This commit is contained in:
Laurent Mazare
2023-10-03 10:21:46 +01:00
committed by GitHub
parent 7b06872f90
commit 043cc25766
2 changed files with 16 additions and 1 deletion

View File

@@ -891,7 +891,7 @@ impl<'a> Map1 for IndexSelect<'a> {
};
let left_size: usize = src_l.dims()[..self.2].iter().product();
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
-let dim_size = src_l.dims()[self.2];
+let dim_size = ids_shape.elem_count();
let dst_el = ids_shape.elem_count() * left_size * right_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;