From 043cc257665fb2398483bb3fb365a18c2ec1e010 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 3 Oct 2023 10:21:46 +0100 Subject: [PATCH] Fix for the index-select cuda setup. (#1022) * Fix for index-select. * Better fix + add some testing. --- candle-core/src/cuda_backend.rs | 2 +- candle-core/tests/tensor_tests.rs | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index d1187b1c..1599425f 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -891,7 +891,7 @@ impl<'a> Map1 for IndexSelect<'a> { }; let left_size: usize = src_l.dims()[..self.2].iter().product(); let right_size: usize = src_l.dims()[self.2 + 1..].iter().product(); - let dim_size = src_l.dims()[self.2]; + let dim_size = ids_shape.elem_count(); let dst_el = ids_shape.elem_count() * left_size * right_size; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let func = dev.get_or_load_func(&kernel_name::(name), kernels::INDEXING)?; diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index d3eede48..2f880158 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -653,6 +653,21 @@ fn index_select(device: &Device) -> Result<()> { hs.to_vec2::()?, &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]] ); + // Prior to https://github.com/huggingface/candle/pull/1022 + // There would be a bug where the last values in the result tensor would be set to 0. + let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?; + let hs = t.index_select(&ids, 0)?; + assert_eq!( + hs.to_vec2::()?, + &[ + [0.0, 1.0, 2.0], + [6.0, 7.0, 8.0], + [3.0, 4.0, 5.0], + [0.0, 1.0, 2.0], + [6.0, 7.0, 8.0], + [3.0, 4.0, 5.0], + ] + ); Ok(()) }