diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 2da10f34..bbbe5faf 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -395,7 +395,7 @@ impl Map1 for IndexSelect<'_> { CudaStorageSlice::U8(slice) => ("is_u8", slice_ptr(slice, ids_l.start_offset())), CudaStorageSlice::I64(slice) => ("is_i64", slice_ptr(slice, ids_l.start_offset())), _ => Err(CudaError::UnexpectedDType { - msg: "index_select ids should be u8 or u32", + msg: "index_select ids should be u8, u32, or i64", expected: DType::U32, got: self.0.dtype(), }) diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs index 20b805c7..520b246f 100644 --- a/candle-core/src/tensor_cat.rs +++ b/candle-core/src/tensor_cat.rs @@ -241,7 +241,7 @@ impl Tensor { /// `self` and `src` must have the same shape except on dimension `dim` where the `self` size /// has to be greater than or equal to `offset` plus the `src` size. /// - /// Note that this modifies `self` in place and as such is not compatibel with + /// Note that this modifies `self` in place and as such is not compatible with /// back-propagation. pub fn slice_set(&self, src: &Self, dim: D, offset: usize) -> Result<()> { let dim = dim.to_index(self.shape(), "slice-set")?; diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 36942ff2..168012c5 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -826,6 +826,31 @@ fn embeddings(device: &Device) -> Result<()> { Ok(()) } +#[test] +fn index_select_fail() -> Result<()> { + // Check that an error is properly reported on out of bounds. 
+ let ids = Tensor::new(&[4u32, 2u32, 1u32], &Device::Cpu)?; + let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &Device::Cpu)?; + let hs = t.index_select(&ids, 0); + assert!(hs.is_err()); + Ok(()) +} + +// The test below triggers an unwinding panic as there is a panic within the cuda code path (exact location not recorded here; presumably related to the device-side bounds asserts in the indexing kernels - TODO confirm), so it is kept disabled. +// #[cfg(feature = "cuda")] +// #[test] +// #[should_panic] +// fn index_select_fail_gpu() { +// // Check that a panic happens for out of bounds in cuda +// if let Ok(device) = Device::new_cuda(0) { +// if let Ok(ids) = Tensor::new(&[4u32, 2u32, 1u32], &device) { +// if let Ok(t) = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &device) { +// let _ = t.index_select(&ids, 0); +// } +// } +// } +// } + fn cmp(device: &Device) -> Result<()> { let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?; let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?; diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index 8af2954d..7074fa0b 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -23,6 +23,7 @@ __device__ void index_select( unsigned int left_i = dst_i / (ids_dim_size * right_size); unsigned int id_i = dst_i / right_size % ids_dim_size; unsigned int right_i = dst_i % right_size; + assert(ids[id_i] < src_dim_size); unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i; unsigned strided_i = b ? 
src_i : get_strided_index(src_i, num_dims, dims, strides); out[dst_i] = inp[strided_i]; @@ -57,6 +58,7 @@ __device__ void gather( for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { size_t post = i % right_size; size_t idx = ids[i]; + assert(idx < src_dim_size); size_t pre = i / (right_size * ids_dim_size); size_t src_i = (pre * src_dim_size + idx) * right_size + post; out[i] = inp[src_i]; @@ -92,6 +94,7 @@ __device__ void index_add( const size_t post = i % right_size; for (unsigned int j = 0; j < ids_dim_size; ++j) { const size_t idx = ids[j]; + assert(idx < dst_dim_size); const size_t src_i = (pre * ids_dim_size + j) * right_size + post; const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; out[dst_i] += inp[src_i]; @@ -128,6 +131,7 @@ __device__ void scatter_add( for (unsigned int j = 0; j < src_dim_size; ++j) { const size_t src_i = (pre * src_dim_size + j) * right_size + post; const size_t idx = ids[src_i]; + assert(idx < dst_dim_size); const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; out[dst_i] += inp[src_i]; }