Indexing with max-value results in zero/no-op. (#2940)

* Indexing with max-value results in zero/no-op.

* Add some testing.

* Also adapt the metal kernels.

* Another test.

* Fix.
This commit is contained in:
Laurent Mazare
2025-05-03 11:36:31 +02:00
committed by GitHub
parent 1fdfb58de5
commit e27b4700ad
5 changed files with 200 additions and 66 deletions

View File

@ -483,17 +483,22 @@ impl<I: IntDType> Map1 for Gather<'_, I> {
let start_dst_idx = start_dst_idx + i * dst_right_len;
for right_i in 0..dst_right_len {
let dst_idx = start_dst_idx + right_i;
let index = ids[dst_idx].as_usize();
if index >= src_dim_len {
Err(Error::InvalidIndex {
index,
size: src_dim_len,
op: "gather",
let index = ids[dst_idx];
if index == I::max_value() {
dst[dst_idx] = T::zero();
} else {
let index = index.as_usize();
if index >= src_dim_len {
Err(Error::InvalidIndex {
index,
size: src_dim_len,
op: "gather",
}
.bt())?
}
.bt())?
let src_idx = start_src_idx + index * src_right_len + right_i;
dst[dst_idx] = src[src_idx]
}
let src_idx = start_src_idx + index * src_right_len + right_i;
dst[dst_idx] = src[src_idx]
}
}
}
@ -535,19 +540,24 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
let start_src_idx = left_i * right_len * src_dim;
let start_dst_idx = left_i * right_len * n_ids;
for i in 0..n_ids {
let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
if index >= src_dim {
Err(Error::InvalidIndex {
index,
size: src_dim,
op: "index-select",
}
.bt())?
}
let start_src_idx = start_src_idx + index * right_len;
let start_dst_idx = start_dst_idx + i * right_len;
dst[start_dst_idx..start_dst_idx + right_len]
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
let index = self.ids[self.ids_l.start_offset() + stride_ids * i];
if index == I::max_value() {
dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero());
} else {
let index = index.as_usize();
if index >= src_dim {
Err(Error::InvalidIndex {
index,
size: src_dim,
op: "index-select",
}
.bt())?
}
let start_src_idx = start_src_idx + index * right_len;
dst[start_dst_idx..start_dst_idx + right_len]
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
}
}
}
Ok(dst)
@ -631,7 +641,11 @@ impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
let start_ids_idx = start_ids_idx + i * ids_right_len;
for right_i in 0..dst_right_len {
let ids_idx = start_ids_idx + right_i;
let index = ids[ids_idx].as_usize();
let index = ids[ids_idx];
if index == I::max_value() {
continue;
}
let index = index.as_usize();
if index >= dst_dim_len {
Err(Error::InvalidIndex {
index,
@ -674,6 +688,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
if dim == 0 {
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
if *dst_idx == I::max_value() {
continue;
}
let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
@ -692,6 +709,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
}
} else {
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
if *dst_idx == I::max_value() {
continue;
}
let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {

View File

@ -180,7 +180,7 @@ with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
pub trait IntDType: WithDType {
pub trait IntDType: WithDType + num_traits::Bounded {
fn is_true(&self) -> bool;
fn as_usize(&self) -> usize;
}

View File

@ -845,6 +845,9 @@ fn embeddings(device: &Device) -> Result<()> {
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let ids = Tensor::new(&[u32::MAX, 2u32, u32::MAX], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]]);
Ok(())
}
@ -1087,6 +1090,31 @@ fn scatter(device: &Device) -> Result<()> {
[1.0, 1.0, 1.0]
]
);
let hs = {
let ids = Tensor::new(
&[
[0u32, u32::MAX, 2],
[3, 4, u32::MAX],
[3, 3, 1],
[u32::MAX, u32::MAX, 4],
],
device,
)?;
init.scatter(&ids, &t, 0)?
};
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 1.0],
[1.0, 1.0, 8.0],
[1.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);
init.scatter_set(&ids, &t, 0)?;
assert_eq!(
init.to_vec2::<f32>()?,
@ -1099,6 +1127,7 @@ fn scatter(device: &Device) -> Result<()> {
[1.0, 1.0, 1.0]
]
);
Ok(())
}
@ -1132,6 +1161,23 @@ fn gather(device: &Device) -> Result<()> {
let hs = t.gather(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
let hs = {
let ids = Tensor::new(
&[
[0u32, 0u32],
[2u32, u32::MAX],
[u32::MAX, 1u32],
[0u32, 2u32],
],
device,
)?;
t.gather(&ids, 1)?
};
assert_eq!(
hs.to_vec2::<f32>()?,
&[[0.0, 0.0], [5.0, 0.0], [0.0, 7.0], [9.0, 11.0]]
);
// Random data
// Dim: 0

View File

@ -3,6 +3,28 @@
#include "cuda_utils.cuh"
#include<stdint.h>
template <typename T>
__host__ __device__
constexpr T max_value();
template <>
__host__ __device__
constexpr int64_t max_value<int64_t>() {
return 0x7FFFFFFFFFFFFFFFLL;
}
template <>
__host__ __device__
constexpr uint32_t max_value<uint32_t>() {
return 0xFFFFFFFFu;
}
template <>
__host__ __device__
constexpr uint8_t max_value<uint8_t>() {
return 0xFFu;
}
template<typename T, typename I>
__device__ void index_select(
const size_t numel,
@ -23,10 +45,14 @@ __device__ void index_select(
unsigned int left_i = dst_i / (ids_dim_size * right_size);
unsigned int id_i = dst_i / right_size % ids_dim_size;
unsigned int right_i = dst_i % right_size;
assert(ids[id_i] < src_dim_size);
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
out[dst_i] = inp[strided_i];
if (ids[id_i] == max_value<I>()) {
out[dst_i] = static_cast<T>(0);
} else {
assert(ids[id_i] < src_dim_size);
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
out[dst_i] = inp[strided_i];
}
}
}
@ -57,11 +83,15 @@ __device__ void gather(
) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
size_t post = i % right_size;
size_t idx = ids[i];
assert(idx < src_dim_size);
size_t pre = i / (right_size * ids_dim_size);
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
out[i] = inp[src_i];
const I idx = ids[i];
if (ids[i] == max_value<I>()) {
out[i] = static_cast<T>(0);
} else {
assert(idx < src_dim_size);
size_t pre = i / (right_size * ids_dim_size);
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
out[i] = inp[src_i];
}
}
}
@ -93,11 +123,13 @@ __device__ void index_add(
const size_t pre = i / right_size;
const size_t post = i % right_size;
for (unsigned int j = 0; j < ids_dim_size; ++j) {
const size_t idx = ids[j];
assert(idx < dst_dim_size);
const I idx = ids[j];
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
if (idx < max_value<I>()) {
assert(idx < dst_dim_size);
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
}
}
@ -130,10 +162,12 @@ __device__ void scatter(
const size_t post = i % right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
const size_t idx = ids[src_i];
assert(idx < dst_dim_size);
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] = inp[src_i];
const I idx = ids[src_i];
if (idx < max_value<I>()) {
assert(idx < dst_dim_size);
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] = inp[src_i];
}
}
}
}
@ -154,10 +188,12 @@ __device__ void scatter_add(
const size_t post = i % right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
const size_t idx = ids[src_i];
assert(idx < dst_dim_size);
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
const I idx = ids[src_i];
if (idx < max_value<I>()) {
assert(idx < dst_dim_size);
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
}
}

View File

@ -1,6 +1,24 @@
#include <metal_stdlib>
using namespace metal;
template <typename T>
inline T max_value();
template <>
inline int64_t max_value<int64_t>() {
return 0x7FFFFFFFFFFFFFFF;
}
template <>
inline uint32_t max_value<uint32_t>() {
return 0xFFFFFFFFu;
}
template <>
inline uint8_t max_value<uint8_t>() {
return 0xFF;
}
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
@ -35,17 +53,21 @@ METAL_FUNC void index(
return;
}
const size_t id_i = (tid / right_size) % ids_size;
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
/*
// Force prevent out of bounds indexing
// since there doesn't seem to be a good way to force crash
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
output[tid] = input[strided_src_i];
if (input_ids[id_i] == max_value<INDEX_TYPENAME>()) {
output[tid] = static_cast<TYPENAME>(0);
} else {
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
/*
// Force prevent out of bounds indexing
// since there doesn't seem to be a good way to force crash
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
output[tid] = input[strided_src_i];
}
}
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@ -83,10 +105,14 @@ METAL_FUNC void gather(
return;
}
const INDEX_TYPENAME input_i = input_ids[tid];
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
output[tid] = input[src_i];
if (input_i == max_value<INDEX_TYPENAME>()) {
output[tid] = static_cast<TYPENAME>(0);
} else {
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
output[tid] = input[src_i];
}
}
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@ -124,8 +150,10 @@ METAL_FUNC void scatter(
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const INDEX_TYPENAME idx = input_ids[src_i];
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] = input[src_i];
if (idx < max_value<INDEX_TYPENAME>()) {
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] = input[src_i];
}
}
}
@ -149,8 +177,10 @@ METAL_FUNC void scatter_add(
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const INDEX_TYPENAME idx = input_ids[src_i];
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
if (idx < max_value<INDEX_TYPENAME>()) {
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
}
}
}
@ -204,9 +234,11 @@ METAL_FUNC void index_add(
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < ids_dim_size; ++j) {
const INDEX_TYPENAME idx = input_ids[j];
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
if (idx < max_value<INDEX_TYPENAME>()) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
}
}
}