Check the bounds in the cuda indexing kernels. (#2908)

* Check the bounds in the cuda indexing kernels.

* Another check.
Laurent Mazare
2025-04-18 20:08:17 +02:00
committed by GitHub
parent 9954981327
commit ce5f8dd129
4 changed files with 31 additions and 2 deletions
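
For context, a minimal sketch of the out-of-bounds case these checks guard against, reusing the tensors from the test added below (the `main` wrapper is illustrative, not part of the commit). On CPU, candle reports the out-of-range id as an `Err`; with this commit, the cuda kernels assert on the device instead of reading out-of-bounds memory.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // A 3x2 tensor: valid indices along dim 0 are 0, 1, and 2.
    let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &Device::Cpu)?;
    // Id 4 is out of bounds along dim 0.
    let ids = Tensor::new(&[4u32, 2u32, 1u32], &Device::Cpu)?;
    // The CPU backend reports this as an error; the cuda kernels now
    // assert on the device rather than reading out-of-bounds memory.
    assert!(t.index_select(&ids, 0).is_err());
    Ok(())
}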

View File

@@ -395,7 +395,7 @@ impl Map1 for IndexSelect<'_> {
             CudaStorageSlice::U8(slice) => ("is_u8", slice_ptr(slice, ids_l.start_offset())),
             CudaStorageSlice::I64(slice) => ("is_i64", slice_ptr(slice, ids_l.start_offset())),
             _ => Err(CudaError::UnexpectedDType {
-                msg: "index_select ids should be u8 or u32",
+                msg: "index_select ids should be u8, u32, or i64",
                 expected: DType::U32,
                 got: self.0.dtype(),
             })
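
As a hedged illustration of the dtype constraint behind this message (not code from the commit): index ids must be u8, u32, or i64, and any other dtype is rejected with `UnexpectedDType`.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32]], &Device::Cpu)?;
    // f32 ids are rejected: index ids must be u8, u32, or i64.
    let bad_ids = Tensor::new(&[0f32, 1f32], &Device::Cpu)?;
    assert!(t.index_select(&bad_ids, 0).is_err());
    // i64 ids (like u8 and u32) are accepted, matching the updated message.
    let ids = Tensor::new(&[1i64, 0i64], &Device::Cpu)?;
    let _hs = t.index_select(&ids, 0)?;
    Ok(())
}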

View File

@@ -241,7 +241,7 @@ impl Tensor {
     /// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
     /// has to be greater than or equal to `offset` plus the `src` size.
     ///
-    /// Note that this modifies `self` in place and as such is not compatibel with
+    /// Note that this modifies `self` in place and as such is not compatible with
     /// back-propagation.
     pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
         let dim = dim.to_index(self.shape(), "slice-set")?;
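
Since the doc comment above spells out the shape contract, a small usage sketch (assumed, not from the commit): writing a 2x2 `src` into a 4x2 `dst` along dim 0 at offset 1 satisfies offset + src size <= dst size.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let dst = Tensor::zeros((4, 2), DType::F32, &dev)?;
    let src = Tensor::ones((2, 2), DType::F32, &dev)?;
    // Same shape except on dim 0, where 1 (offset) + 2 (src) <= 4 (dst).
    dst.slice_set(&src, 0, 1)?;
    // The write happens in place, so it is not compatible with back-propagation.
    Ok(())
}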

View File

@@ -826,6 +826,31 @@ fn embeddings(device: &Device) -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn index_select_fail() -> Result<()> {
+    // Check that an error is properly reported on out of bounds.
+    let ids = Tensor::new(&[4u32, 2u32, 1u32], &Device::Cpu)?;
+    let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &Device::Cpu)?;
+    let hs = t.index_select(&ids, 0);
+    assert!(hs.is_err());
+    Ok(())
+}
+
+// The test below triggers an unwinding panic, as there is a panic within the cuda kernel call.
+// #[cfg(feature = "cuda")]
+// #[test]
+// #[should_panic]
+// fn index_select_fail_gpu() {
+//     // Check that a panic happens for out of bounds in cuda.
+//     if let Ok(device) = Device::new_cuda(0) {
+//         if let Ok(ids) = Tensor::new(&[4u32, 2u32, 1u32], &device) {
+//             if let Ok(t) = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &device) {
+//                 let _ = t.index_select(&ids, 0);
+//             }
+//         }
+//     }
+// }
+
 fn cmp(device: &Device) -> Result<()> {
     let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
     let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
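
Since the gpu test above has to stay commented out, a hypothetical alternative (an assumption, not part of the commit) is to wrap the call in `catch_unwind` so the panic can be observed without `#[should_panic]`:

use candle_core::{Device, Tensor};
use std::panic::{catch_unwind, AssertUnwindSafe};

fn main() {
    // Requires a cuda-enabled build and a visible GPU; skip otherwise.
    if let Ok(device) = Device::new_cuda(0) {
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            let ids = Tensor::new(&[4u32, 2u32, 1u32], &device)?;
            let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &device)?;
            t.index_select(&ids, 0)
        }));
        match outcome {
            Err(_) => (),                     // panicked, as the commented-out test expects
            Ok(res) => assert!(res.is_err()), // or surfaced as a plain error
        }
    }
}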

View File

@@ -23,6 +23,7 @@ __device__ void index_select(
         unsigned int left_i = dst_i / (ids_dim_size * right_size);
         unsigned int id_i = dst_i / right_size % ids_dim_size;
         unsigned int right_i = dst_i % right_size;
+        assert(ids[id_i] < src_dim_size);
         unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
         unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
         out[dst_i] = inp[strided_i];
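
To make the index arithmetic above easier to follow, a plain-Rust sketch (an illustration, not candle code) of the contiguous index_select path; the assert mirrors the bound the kernel now checks, and gather follows the same pattern with one id per output element.

// Reference version of the kernel's contiguous path: dst_i decomposes into
// (left_i, id_i, right_i), and ids[id_i] replaces the indexed dimension.
fn index_select_ref(inp: &[f32], ids: &[usize], src_dim_size: usize, right_size: usize) -> Vec<f32> {
    let left_size = inp.len() / (src_dim_size * right_size);
    let mut out = vec![0f32; left_size * ids.len() * right_size];
    for dst_i in 0..out.len() {
        let left_i = dst_i / (ids.len() * right_size);
        let id_i = dst_i / right_size % ids.len();
        let right_i = dst_i % right_size;
        // The bound the new device-side assert enforces.
        assert!(ids[id_i] < src_dim_size);
        let src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
        out[dst_i] = inp[src_i];
    }
    out
}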
@@ -57,6 +58,7 @@ __device__ void gather(
     for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
         size_t post = i % right_size;
         size_t idx = ids[i];
+        assert(idx < src_dim_size);
         size_t pre = i / (right_size * ids_dim_size);
         size_t src_i = (pre * src_dim_size + idx) * right_size + post;
         out[i] = inp[src_i];
@@ -92,6 +94,7 @@ __device__ void index_add(
         const size_t post = i % right_size;
         for (unsigned int j = 0; j < ids_dim_size; ++j) {
             const size_t idx = ids[j];
+            assert(idx < dst_dim_size);
             const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
             const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
             out[dst_i] += inp[src_i];
@@ -128,6 +131,7 @@ __device__ void scatter_add(
         for (unsigned int j = 0; j < src_dim_size; ++j) {
             const size_t src_i = (pre * src_dim_size + j) * right_size + post;
             const size_t idx = ids[src_i];
+            assert(idx < dst_dim_size);
             const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
             out[dst_i] += inp[src_i];
         }
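
And a matching plain-Rust sketch (again illustrative, not candle code) of the scatter_add loop, showing the bound the new assert enforces: every id read from ids must index a valid slot in the destination dimension. index_add is analogous, with ids indexed by j rather than by src_i.

// Reference version of scatter_add for contiguous buffers; pre/post mirror
// the kernel's decomposition of the flattened iteration index.
fn scatter_add_ref(
    out: &mut [f32], inp: &[f32], ids: &[usize],
    left_size: usize, src_dim_size: usize, dst_dim_size: usize, right_size: usize,
) {
    for pre in 0..left_size {
        for post in 0..right_size {
            for j in 0..src_dim_size {
                let src_i = (pre * src_dim_size + j) * right_size + post;
                let idx = ids[src_i];
                // The bound the new device-side assert enforces.
                assert!(idx < dst_dim_size);
                let dst_i = (pre * dst_dim_size + idx) * right_size + post;
                out[dst_i] += inp[src_i];
            }
        }
    }
}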