mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Indexing with max-value results in zero/no-op. (#2940)
* Indexing with max-value results in zero/no-op. * Add some testing. * Also adapt the metal kernels. * Another test. * Fix.
This commit is contained in:
@ -483,7 +483,11 @@ impl<I: IntDType> Map1 for Gather<'_, I> {
|
|||||||
let start_dst_idx = start_dst_idx + i * dst_right_len;
|
let start_dst_idx = start_dst_idx + i * dst_right_len;
|
||||||
for right_i in 0..dst_right_len {
|
for right_i in 0..dst_right_len {
|
||||||
let dst_idx = start_dst_idx + right_i;
|
let dst_idx = start_dst_idx + right_i;
|
||||||
let index = ids[dst_idx].as_usize();
|
let index = ids[dst_idx];
|
||||||
|
if index == I::max_value() {
|
||||||
|
dst[dst_idx] = T::zero();
|
||||||
|
} else {
|
||||||
|
let index = index.as_usize();
|
||||||
if index >= src_dim_len {
|
if index >= src_dim_len {
|
||||||
Err(Error::InvalidIndex {
|
Err(Error::InvalidIndex {
|
||||||
index,
|
index,
|
||||||
@ -497,6 +501,7 @@ impl<I: IntDType> Map1 for Gather<'_, I> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -535,7 +540,12 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
|
|||||||
let start_src_idx = left_i * right_len * src_dim;
|
let start_src_idx = left_i * right_len * src_dim;
|
||||||
let start_dst_idx = left_i * right_len * n_ids;
|
let start_dst_idx = left_i * right_len * n_ids;
|
||||||
for i in 0..n_ids {
|
for i in 0..n_ids {
|
||||||
let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
|
let start_dst_idx = start_dst_idx + i * right_len;
|
||||||
|
let index = self.ids[self.ids_l.start_offset() + stride_ids * i];
|
||||||
|
if index == I::max_value() {
|
||||||
|
dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero());
|
||||||
|
} else {
|
||||||
|
let index = index.as_usize();
|
||||||
if index >= src_dim {
|
if index >= src_dim {
|
||||||
Err(Error::InvalidIndex {
|
Err(Error::InvalidIndex {
|
||||||
index,
|
index,
|
||||||
@ -545,11 +555,11 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
|
|||||||
.bt())?
|
.bt())?
|
||||||
}
|
}
|
||||||
let start_src_idx = start_src_idx + index * right_len;
|
let start_src_idx = start_src_idx + index * right_len;
|
||||||
let start_dst_idx = start_dst_idx + i * right_len;
|
|
||||||
dst[start_dst_idx..start_dst_idx + right_len]
|
dst[start_dst_idx..start_dst_idx + right_len]
|
||||||
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
|
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -631,7 +641,11 @@ impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
|
|||||||
let start_ids_idx = start_ids_idx + i * ids_right_len;
|
let start_ids_idx = start_ids_idx + i * ids_right_len;
|
||||||
for right_i in 0..dst_right_len {
|
for right_i in 0..dst_right_len {
|
||||||
let ids_idx = start_ids_idx + right_i;
|
let ids_idx = start_ids_idx + right_i;
|
||||||
let index = ids[ids_idx].as_usize();
|
let index = ids[ids_idx];
|
||||||
|
if index == I::max_value() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let index = index.as_usize();
|
||||||
if index >= dst_dim_len {
|
if index >= dst_dim_len {
|
||||||
Err(Error::InvalidIndex {
|
Err(Error::InvalidIndex {
|
||||||
index,
|
index,
|
||||||
@ -674,6 +688,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
|
|||||||
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
|
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
|
||||||
if dim == 0 {
|
if dim == 0 {
|
||||||
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
||||||
|
if *dst_idx == I::max_value() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
let dst_idx = dst_idx.as_usize();
|
let dst_idx = dst_idx.as_usize();
|
||||||
if dst_idx >= max_idx {
|
if dst_idx >= max_idx {
|
||||||
Err(Error::InvalidIndex {
|
Err(Error::InvalidIndex {
|
||||||
@ -692,6 +709,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
||||||
|
if *dst_idx == I::max_value() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
let dst_idx = dst_idx.as_usize();
|
let dst_idx = dst_idx.as_usize();
|
||||||
if dst_idx >= max_idx {
|
if dst_idx >= max_idx {
|
||||||
Err(Error::InvalidIndex {
|
Err(Error::InvalidIndex {
|
||||||
|
@ -180,7 +180,7 @@ with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
|
|||||||
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
|
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
|
||||||
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
|
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
|
||||||
|
|
||||||
pub trait IntDType: WithDType {
|
pub trait IntDType: WithDType + num_traits::Bounded {
|
||||||
fn is_true(&self) -> bool;
|
fn is_true(&self) -> bool;
|
||||||
fn as_usize(&self) -> usize;
|
fn as_usize(&self) -> usize;
|
||||||
}
|
}
|
||||||
|
@ -845,6 +845,9 @@ fn embeddings(device: &Device) -> Result<()> {
|
|||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
|
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
|
let ids = Tensor::new(&[u32::MAX, 2u32, u32::MAX], device)?;
|
||||||
|
let hs = t.index_select(&ids, 0)?;
|
||||||
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1087,6 +1090,31 @@ fn scatter(device: &Device) -> Result<()> {
|
|||||||
[1.0, 1.0, 1.0]
|
[1.0, 1.0, 1.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let hs = {
|
||||||
|
let ids = Tensor::new(
|
||||||
|
&[
|
||||||
|
[0u32, u32::MAX, 2],
|
||||||
|
[3, 4, u32::MAX],
|
||||||
|
[3, 3, 1],
|
||||||
|
[u32::MAX, u32::MAX, 4],
|
||||||
|
],
|
||||||
|
device,
|
||||||
|
)?;
|
||||||
|
init.scatter(&ids, &t, 0)?
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
hs.to_vec2::<f32>()?,
|
||||||
|
&[
|
||||||
|
[0.0, 1.0, 1.0],
|
||||||
|
[1.0, 1.0, 8.0],
|
||||||
|
[1.0, 1.0, 2.0],
|
||||||
|
[6.0, 7.0, 1.0],
|
||||||
|
[1.0, 4.0, 11.0],
|
||||||
|
[1.0, 1.0, 1.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
init.scatter_set(&ids, &t, 0)?;
|
init.scatter_set(&ids, &t, 0)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
init.to_vec2::<f32>()?,
|
init.to_vec2::<f32>()?,
|
||||||
@ -1099,6 +1127,7 @@ fn scatter(device: &Device) -> Result<()> {
|
|||||||
[1.0, 1.0, 1.0]
|
[1.0, 1.0, 1.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1132,6 +1161,23 @@ fn gather(device: &Device) -> Result<()> {
|
|||||||
let hs = t.gather(&ids, 0)?;
|
let hs = t.gather(&ids, 0)?;
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
|
||||||
|
|
||||||
|
let hs = {
|
||||||
|
let ids = Tensor::new(
|
||||||
|
&[
|
||||||
|
[0u32, 0u32],
|
||||||
|
[2u32, u32::MAX],
|
||||||
|
[u32::MAX, 1u32],
|
||||||
|
[0u32, 2u32],
|
||||||
|
],
|
||||||
|
device,
|
||||||
|
)?;
|
||||||
|
t.gather(&ids, 1)?
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
hs.to_vec2::<f32>()?,
|
||||||
|
&[[0.0, 0.0], [5.0, 0.0], [0.0, 7.0], [9.0, 11.0]]
|
||||||
|
);
|
||||||
|
|
||||||
// Random data
|
// Random data
|
||||||
|
|
||||||
// Dim: 0
|
// Dim: 0
|
||||||
|
@ -3,6 +3,28 @@
|
|||||||
#include "cuda_utils.cuh"
|
#include "cuda_utils.cuh"
|
||||||
#include<stdint.h>
|
#include<stdint.h>
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__host__ __device__
|
||||||
|
constexpr T max_value();
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ __device__
|
||||||
|
constexpr int64_t max_value<int64_t>() {
|
||||||
|
return 0x7FFFFFFFFFFFFFFFLL;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ __device__
|
||||||
|
constexpr uint32_t max_value<uint32_t>() {
|
||||||
|
return 0xFFFFFFFFu;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ __device__
|
||||||
|
constexpr uint8_t max_value<uint8_t>() {
|
||||||
|
return 0xFFu;
|
||||||
|
}
|
||||||
|
|
||||||
template<typename T, typename I>
|
template<typename T, typename I>
|
||||||
__device__ void index_select(
|
__device__ void index_select(
|
||||||
const size_t numel,
|
const size_t numel,
|
||||||
@ -23,12 +45,16 @@ __device__ void index_select(
|
|||||||
unsigned int left_i = dst_i / (ids_dim_size * right_size);
|
unsigned int left_i = dst_i / (ids_dim_size * right_size);
|
||||||
unsigned int id_i = dst_i / right_size % ids_dim_size;
|
unsigned int id_i = dst_i / right_size % ids_dim_size;
|
||||||
unsigned int right_i = dst_i % right_size;
|
unsigned int right_i = dst_i % right_size;
|
||||||
|
if (ids[id_i] == max_value<I>()) {
|
||||||
|
out[dst_i] = static_cast<T>(0);
|
||||||
|
} else {
|
||||||
assert(ids[id_i] < src_dim_size);
|
assert(ids[id_i] < src_dim_size);
|
||||||
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
|
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
|
||||||
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
|
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
|
||||||
out[dst_i] = inp[strided_i];
|
out[dst_i] = inp[strided_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#define IS_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
#define IS_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
@ -57,13 +83,17 @@ __device__ void gather(
|
|||||||
) {
|
) {
|
||||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||||
size_t post = i % right_size;
|
size_t post = i % right_size;
|
||||||
size_t idx = ids[i];
|
const I idx = ids[i];
|
||||||
|
if (ids[i] == max_value<I>()) {
|
||||||
|
out[i] = static_cast<T>(0);
|
||||||
|
} else {
|
||||||
assert(idx < src_dim_size);
|
assert(idx < src_dim_size);
|
||||||
size_t pre = i / (right_size * ids_dim_size);
|
size_t pre = i / (right_size * ids_dim_size);
|
||||||
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
|
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
|
||||||
out[i] = inp[src_i];
|
out[i] = inp[src_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#define GATHER_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
#define GATHER_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
@ -93,14 +123,16 @@ __device__ void index_add(
|
|||||||
const size_t pre = i / right_size;
|
const size_t pre = i / right_size;
|
||||||
const size_t post = i % right_size;
|
const size_t post = i % right_size;
|
||||||
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
||||||
const size_t idx = ids[j];
|
const I idx = ids[j];
|
||||||
assert(idx < dst_dim_size);
|
|
||||||
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
|
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
|
||||||
|
if (idx < max_value<I>()) {
|
||||||
|
assert(idx < dst_dim_size);
|
||||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||||
out[dst_i] += inp[src_i];
|
out[dst_i] += inp[src_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
@ -130,13 +162,15 @@ __device__ void scatter(
|
|||||||
const size_t post = i % right_size;
|
const size_t post = i % right_size;
|
||||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
||||||
const size_t idx = ids[src_i];
|
const I idx = ids[src_i];
|
||||||
|
if (idx < max_value<I>()) {
|
||||||
assert(idx < dst_dim_size);
|
assert(idx < dst_dim_size);
|
||||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||||
out[dst_i] = inp[src_i];
|
out[dst_i] = inp[src_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template<typename T, typename I>
|
template<typename T, typename I>
|
||||||
__device__ void scatter_add(
|
__device__ void scatter_add(
|
||||||
@ -154,13 +188,15 @@ __device__ void scatter_add(
|
|||||||
const size_t post = i % right_size;
|
const size_t post = i % right_size;
|
||||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
||||||
const size_t idx = ids[src_i];
|
const I idx = ids[src_i];
|
||||||
|
if (idx < max_value<I>()) {
|
||||||
assert(idx < dst_dim_size);
|
assert(idx < dst_dim_size);
|
||||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||||
out[dst_i] += inp[src_i];
|
out[dst_i] += inp[src_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#define S_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
#define S_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
|
@ -1,6 +1,24 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T max_value();
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline int64_t max_value<int64_t>() {
|
||||||
|
return 0x7FFFFFFFFFFFFFFF;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline uint32_t max_value<uint32_t>() {
|
||||||
|
return 0xFFFFFFFFu;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline uint8_t max_value<uint8_t>() {
|
||||||
|
return 0xFF;
|
||||||
|
}
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
METAL_FUNC uint get_strided_index(
|
||||||
uint idx,
|
uint idx,
|
||||||
constant size_t &num_dims,
|
constant size_t &num_dims,
|
||||||
@ -35,6 +53,9 @@ METAL_FUNC void index(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const size_t id_i = (tid / right_size) % ids_size;
|
const size_t id_i = (tid / right_size) % ids_size;
|
||||||
|
if (input_ids[id_i] == max_value<INDEX_TYPENAME>()) {
|
||||||
|
output[tid] = static_cast<TYPENAME>(0);
|
||||||
|
} else {
|
||||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
|
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
|
||||||
const size_t right_rank_i = tid % right_size;
|
const size_t right_rank_i = tid % right_size;
|
||||||
const size_t left_rank_i = tid / right_size / ids_size;
|
const size_t left_rank_i = tid / right_size / ids_size;
|
||||||
@ -47,6 +68,7 @@ METAL_FUNC void index(
|
|||||||
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
|
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
|
||||||
output[tid] = input[strided_src_i];
|
output[tid] = input[strided_src_i];
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
kernel void NAME( \
|
kernel void NAME( \
|
||||||
@ -83,11 +105,15 @@ METAL_FUNC void gather(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const INDEX_TYPENAME input_i = input_ids[tid];
|
const INDEX_TYPENAME input_i = input_ids[tid];
|
||||||
|
if (input_i == max_value<INDEX_TYPENAME>()) {
|
||||||
|
output[tid] = static_cast<TYPENAME>(0);
|
||||||
|
} else {
|
||||||
const size_t right_rank_i = tid % right_size;
|
const size_t right_rank_i = tid % right_size;
|
||||||
const size_t left_rank_i = tid / right_size / ids_size;
|
const size_t left_rank_i = tid / right_size / ids_size;
|
||||||
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
|
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
|
||||||
output[tid] = input[src_i];
|
output[tid] = input[src_i];
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
kernel void NAME( \
|
kernel void NAME( \
|
||||||
@ -124,10 +150,12 @@ METAL_FUNC void scatter(
|
|||||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||||
const INDEX_TYPENAME idx = input_ids[src_i];
|
const INDEX_TYPENAME idx = input_ids[src_i];
|
||||||
|
if (idx < max_value<INDEX_TYPENAME>()) {
|
||||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||||
output[dst_i] = input[src_i];
|
output[dst_i] = input[src_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||||
METAL_FUNC void scatter_add(
|
METAL_FUNC void scatter_add(
|
||||||
@ -149,10 +177,12 @@ METAL_FUNC void scatter_add(
|
|||||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||||
const INDEX_TYPENAME idx = input_ids[src_i];
|
const INDEX_TYPENAME idx = input_ids[src_i];
|
||||||
|
if (idx < max_value<INDEX_TYPENAME>()) {
|
||||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||||
output[dst_i] += input[src_i];
|
output[dst_i] += input[src_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# define SCATTER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
# define SCATTER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
kernel void NAME( \
|
kernel void NAME( \
|
||||||
@ -204,11 +234,13 @@ METAL_FUNC void index_add(
|
|||||||
const size_t left_rank_i = tid / right_size;
|
const size_t left_rank_i = tid / right_size;
|
||||||
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
||||||
const INDEX_TYPENAME idx = input_ids[j];
|
const INDEX_TYPENAME idx = input_ids[j];
|
||||||
|
if (idx < max_value<INDEX_TYPENAME>()) {
|
||||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||||
output[dst_i] += input[src_i];
|
output[dst_i] += input[src_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
kernel void NAME( \
|
kernel void NAME( \
|
||||||
|
Reference in New Issue
Block a user