mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
Fix for the index-select cuda setup. (#1022)
* Fix for index-select. * Better fix + add some testing.
This commit is contained in:
@ -891,7 +891,7 @@ impl<'a> Map1 for IndexSelect<'a> {
|
||||
};
|
||||
let left_size: usize = src_l.dims()[..self.2].iter().product();
|
||||
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
|
||||
let dim_size = src_l.dims()[self.2];
|
||||
let dim_size = ids_shape.elem_count();
|
||||
let dst_el = ids_shape.elem_count() * left_size * right_size;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
||||
|
Reference in New Issue
Block a user