diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 347710de..af7cb5bd 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -483,17 +483,22 @@ impl Map1 for Gather<'_, I> { let start_dst_idx = start_dst_idx + i * dst_right_len; for right_i in 0..dst_right_len { let dst_idx = start_dst_idx + right_i; - let index = ids[dst_idx].as_usize(); - if index >= src_dim_len { - Err(Error::InvalidIndex { - index, - size: src_dim_len, - op: "gather", + let index = ids[dst_idx]; + if index == I::max_value() { + dst[dst_idx] = T::zero(); + } else { + let index = index.as_usize(); + if index >= src_dim_len { + Err(Error::InvalidIndex { + index, + size: src_dim_len, + op: "gather", + } + .bt())? } - .bt())? + let src_idx = start_src_idx + index * src_right_len + right_i; + dst[dst_idx] = src[src_idx] } - let src_idx = start_src_idx + index * src_right_len + right_i; - dst[dst_idx] = src[src_idx] } } } @@ -535,19 +540,24 @@ impl Map1 for IndexSelect<'_, I> { let start_src_idx = left_i * right_len * src_dim; let start_dst_idx = left_i * right_len * n_ids; for i in 0..n_ids { - let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize(); - if index >= src_dim { - Err(Error::InvalidIndex { - index, - size: src_dim, - op: "index-select", - } - .bt())? - } - let start_src_idx = start_src_idx + index * right_len; let start_dst_idx = start_dst_idx + i * right_len; - dst[start_dst_idx..start_dst_idx + right_len] - .copy_from_slice(&src[start_src_idx..start_src_idx + right_len]) + let index = self.ids[self.ids_l.start_offset() + stride_ids * i]; + if index == I::max_value() { + dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero()); + } else { + let index = index.as_usize(); + if index >= src_dim { + Err(Error::InvalidIndex { + index, + size: src_dim, + op: "index-select", + } + .bt())? + } + let start_src_idx = start_src_idx + index * right_len; + dst[start_dst_idx..start_dst_idx + right_len] + .copy_from_slice(&src[start_src_idx..start_src_idx + right_len]) + } } } Ok(dst) @@ -631,7 +641,11 @@ impl Map2InPlace for Scatter<'_, I, M> { let start_ids_idx = start_ids_idx + i * ids_right_len; for right_i in 0..dst_right_len { let ids_idx = start_ids_idx + right_i; - let index = ids[ids_idx].as_usize(); + let index = ids[ids_idx]; + if index == I::max_value() { + continue; + } + let index = index.as_usize(); if index >= dst_dim_len { Err(Error::InvalidIndex { index, @@ -674,6 +688,9 @@ impl Map2 for IndexAdd<'_, I> { let post_dim = src_l.dims()[dim + 1..].iter().product::(); if dim == 0 { for (src_idx, dst_idx) in self.ids.iter().enumerate() { + if *dst_idx == I::max_value() { + continue; + } let dst_idx = dst_idx.as_usize(); if dst_idx >= max_idx { Err(Error::InvalidIndex { @@ -692,6 +709,9 @@ impl Map2 for IndexAdd<'_, I> { } } else { for (src_idx, dst_idx) in self.ids.iter().enumerate() { + if *dst_idx == I::max_value() { + continue; + } let dst_idx = dst_idx.as_usize(); if dst_idx >= max_idx { Err(Error::InvalidIndex { diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 1908e600..b0697c19 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -180,7 +180,7 @@ with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64); with_dtype!(f64, F64, |v: f64| v, |v: f64| v); -pub trait IntDType: WithDType { +pub trait IntDType: WithDType + num_traits::Bounded { fn is_true(&self) -> bool; fn as_usize(&self) -> usize; } diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 309e705e..c443ad2a 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -845,6 +845,9 @@ fn embeddings(device: &Device) -> Result<()> { assert_eq!(hs.to_vec2::()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?; assert_eq!(hs.to_vec2::()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); + let ids = Tensor::new(&[u32::MAX, 2u32, u32::MAX], device)?; + let hs = t.index_select(&ids, 0)?; + assert_eq!(hs.to_vec2::()?, &[[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]]); Ok(()) } @@ -1087,6 +1090,31 @@ fn scatter(device: &Device) -> Result<()> { [1.0, 1.0, 1.0] ] ); + + let hs = { + let ids = Tensor::new( + &[ + [0u32, u32::MAX, 2], + [3, 4, u32::MAX], + [3, 3, 1], + [u32::MAX, u32::MAX, 4], + ], + device, + )?; + init.scatter(&ids, &t, 0)? + }; + assert_eq!( + hs.to_vec2::()?, + &[ + [0.0, 1.0, 1.0], + [1.0, 1.0, 8.0], + [1.0, 1.0, 2.0], + [6.0, 7.0, 1.0], + [1.0, 4.0, 11.0], + [1.0, 1.0, 1.0] + ] + ); + init.scatter_set(&ids, &t, 0)?; assert_eq!( init.to_vec2::()?, @@ -1099,6 +1127,7 @@ fn scatter(device: &Device) -> Result<()> { [1.0, 1.0, 1.0] ] ); + Ok(()) } @@ -1132,6 +1161,23 @@ fn gather(device: &Device) -> Result<()> { let hs = t.gather(&ids, 0)?; assert_eq!(hs.to_vec2::()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]); + let hs = { + let ids = Tensor::new( + &[ + [0u32, 0u32], + [2u32, u32::MAX], + [u32::MAX, 1u32], + [0u32, 2u32], + ], + device, + )?; + t.gather(&ids, 1)? + }; + assert_eq!( + hs.to_vec2::()?, + &[[0.0, 0.0], [5.0, 0.0], [0.0, 7.0], [9.0, 11.0]] + ); + // Random data // Dim: 0 diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index f2327f27..d023280d 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -3,6 +3,28 @@ #include "cuda_utils.cuh" #include +template +__host__ __device__ +constexpr T max_value(); + +template <> +__host__ __device__ +constexpr int64_t max_value() { + return 0x7FFFFFFFFFFFFFFFLL; +} + +template <> +__host__ __device__ +constexpr uint32_t max_value() { + return 0xFFFFFFFFu; +} + +template <> +__host__ __device__ +constexpr uint8_t max_value() { + return 0xFFu; +} + template __device__ void index_select( const size_t numel, @@ -23,10 +45,14 @@ __device__ void index_select( unsigned int left_i = dst_i / (ids_dim_size * right_size); unsigned int id_i = dst_i / right_size % ids_dim_size; unsigned int right_i = dst_i % right_size; - assert(ids[id_i] < src_dim_size); - unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i; - unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides); - out[dst_i] = inp[strided_i]; + if (ids[id_i] == max_value()) { + out[dst_i] = static_cast(0); + } else { + assert(ids[id_i] < src_dim_size); + unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i; + unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides); + out[dst_i] = inp[strided_i]; + } } } @@ -57,11 +83,15 @@ __device__ void gather( ) { for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { size_t post = i % right_size; - size_t idx = ids[i]; - assert(idx < src_dim_size); - size_t pre = i / (right_size * ids_dim_size); - size_t src_i = (pre * src_dim_size + idx) * right_size + post; - out[i] = inp[src_i]; + const I idx = ids[i]; + if (ids[i] == max_value()) { + out[i] = static_cast(0); + } else { + assert(idx < src_dim_size); + size_t pre = i / (right_size * ids_dim_size); + size_t src_i = (pre * src_dim_size + idx) * right_size + post; + out[i] = inp[src_i]; + } } } @@ -93,11 +123,13 @@ __device__ void index_add( const size_t pre = i / right_size; const size_t post = i % right_size; for (unsigned int j = 0; j < ids_dim_size; ++j) { - const size_t idx = ids[j]; - assert(idx < dst_dim_size); + const I idx = ids[j]; const size_t src_i = (pre * ids_dim_size + j) * right_size + post; - const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; - out[dst_i] += inp[src_i]; + if (idx < max_value()) { + assert(idx < dst_dim_size); + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] += inp[src_i]; + } } } } @@ -130,10 +162,12 @@ __device__ void scatter( const size_t post = i % right_size; for (unsigned int j = 0; j < src_dim_size; ++j) { const size_t src_i = (pre * src_dim_size + j) * right_size + post; - const size_t idx = ids[src_i]; - assert(idx < dst_dim_size); - const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; - out[dst_i] = inp[src_i]; + const I idx = ids[src_i]; + if (idx < max_value()) { + assert(idx < dst_dim_size); + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] = inp[src_i]; + } } } } @@ -154,10 +188,12 @@ __device__ void scatter_add( const size_t post = i % right_size; for (unsigned int j = 0; j < src_dim_size; ++j) { const size_t src_i = (pre * src_dim_size + j) * right_size + post; - const size_t idx = ids[src_i]; - assert(idx < dst_dim_size); - const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; - out[dst_i] += inp[src_i]; + const I idx = ids[src_i]; + if (idx < max_value()) { + assert(idx < dst_dim_size); + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] += inp[src_i]; + } } } } diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index d596a619..4c0cf8c0 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -1,6 +1,24 @@ #include using namespace metal; +template +inline T max_value(); + +template <> +inline int64_t max_value() { + return 0x7FFFFFFFFFFFFFFF; +} + +template <> +inline uint32_t max_value() { + return 0xFFFFFFFFu; +} + +template <> +inline uint8_t max_value() { + return 0xFF; +} + METAL_FUNC uint get_strided_index( uint idx, constant size_t &num_dims, @@ -35,17 +53,21 @@ METAL_FUNC void index( return; } const size_t id_i = (tid / right_size) % ids_size; - const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); - const size_t right_rank_i = tid % right_size; - const size_t left_rank_i = tid / right_size / ids_size; - /* - // Force prevent out of bounds indexing - // since there doesn't seem to be a good way to force crash - // No need to check for zero we're only allowing unsized. - */ - const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; - const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides); - output[tid] = input[strided_src_i]; + if (input_ids[id_i] == max_value()) { + output[tid] = static_cast(0); + } else { + const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size / ids_size; + /* + // Force prevent out of bounds indexing + // since there doesn't seem to be a good way to force crash + // No need to check for zero we're only allowing unsized. + */ + const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; + const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides); + output[tid] = input[strided_src_i]; + } } # define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \ @@ -83,10 +105,14 @@ METAL_FUNC void gather( return; } const INDEX_TYPENAME input_i = input_ids[tid]; - const size_t right_rank_i = tid % right_size; - const size_t left_rank_i = tid / right_size / ids_size; - const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; - output[tid] = input[src_i]; + if (input_i == max_value()) { + output[tid] = static_cast(0); + } else { + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size / ids_size; + const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; + output[tid] = input[src_i]; + } } # define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \ @@ -124,8 +150,10 @@ METAL_FUNC void scatter( for (unsigned int j = 0; j < src_dim_size; ++j) { const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; const INDEX_TYPENAME idx = input_ids[src_i]; - const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; - output[dst_i] = input[src_i]; + if (idx < max_value()) { + const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; + output[dst_i] = input[src_i]; + } } } @@ -149,8 +177,10 @@ METAL_FUNC void scatter_add( for (unsigned int j = 0; j < src_dim_size; ++j) { const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; const INDEX_TYPENAME idx = input_ids[src_i]; - const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; - output[dst_i] += input[src_i]; + if (idx < max_value()) { + const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; + output[dst_i] += input[src_i]; + } } } @@ -204,9 +234,11 @@ METAL_FUNC void index_add( const size_t left_rank_i = tid / right_size; for (unsigned int j = 0; j < ids_dim_size; ++j) { const INDEX_TYPENAME idx = input_ids[j]; - const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; - const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; - output[dst_i] += input[src_i]; + if (idx < max_value()) { + const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; + const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; + output[dst_i] += input[src_i]; + } } }