Fix for the index-select cuda setup. (#1022)

* Fix for index-select.

* Better fix + add some testing.
This commit is contained in:
Laurent Mazare
2023-10-03 10:21:46 +01:00
committed by GitHub
parent 7b06872f90
commit 043cc25766
2 changed files with 16 additions and 1 deletions

View File

@ -891,7 +891,7 @@ impl<'a> Map1 for IndexSelect<'a> {
}; };
let left_size: usize = src_l.dims()[..self.2].iter().product(); let left_size: usize = src_l.dims()[..self.2].iter().product();
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product(); let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
let dim_size = src_l.dims()[self.2]; let dim_size = ids_shape.elem_count();
let dst_el = ids_shape.elem_count() * left_size * right_size; let dst_el = ids_shape.elem_count() * left_size * right_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32); let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?; let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;

View File

@ -653,6 +653,21 @@ fn index_select(device: &Device) -> Result<()> {
hs.to_vec2::<f32>()?, hs.to_vec2::<f32>()?,
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]] &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
); );
// Prior to https://github.com/huggingface/candle/pull/1022
// There would be a bug where the last values in the result tensor would be set to 0.
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
]
);
Ok(()) Ok(())
} }