mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Index add.
This commit is contained in:
@ -951,14 +951,49 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
fn index_add(
|
fn index_add(
|
||||||
&self,
|
&self,
|
||||||
_: &Layout,
|
l: &Layout,
|
||||||
_: &Self,
|
ids: &Self,
|
||||||
_: &Layout,
|
ids_l: &Layout,
|
||||||
_: &Self,
|
src: &Self,
|
||||||
_: &Layout,
|
src_l: &Layout,
|
||||||
_: usize,
|
dim: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
crate::bail!("index_add metal")
|
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
|
||||||
|
self.copy_strided_src(&mut acc, 0, l)?;
|
||||||
|
let (ids_offset, _) = match ids_l.contiguous_offsets() {
|
||||||
|
Some(o12) => o12,
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||||
|
};
|
||||||
|
let src_offset = match src_l.contiguous_offsets() {
|
||||||
|
Some((o1, _)) => o1,
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||||
|
};
|
||||||
|
let name = match (ids.dtype, self.dtype) {
|
||||||
|
(DType::U32, DType::F32) => "ia_u32_f32",
|
||||||
|
_ => Err(MetalError::UnexpectedDType {
|
||||||
|
msg: "index-add ids should be u8/u32/i64",
|
||||||
|
expected: DType::U32,
|
||||||
|
got: ids.dtype(),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
let command_buffer = self.device.command_buffer()?;
|
||||||
|
candle_metal_kernels::call_index_add(
|
||||||
|
&self.device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&self.device.kernels,
|
||||||
|
name,
|
||||||
|
src_l.dims(),
|
||||||
|
l.dims(),
|
||||||
|
ids_l.dims(),
|
||||||
|
dim,
|
||||||
|
&src.buffer,
|
||||||
|
src_offset * src.dtype.size_in_bytes(),
|
||||||
|
&ids.buffer,
|
||||||
|
ids_offset * ids.dtype.size_in_bytes(),
|
||||||
|
&acc.buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
Ok(acc)
|
||||||
}
|
}
|
||||||
fn matmul(
|
fn matmul(
|
||||||
&self,
|
&self,
|
||||||
|
@ -122,48 +122,47 @@ kernel void NAME( \
|
|||||||
scatter_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
|
scatter_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||||
template <typename T, typename I>
|
METAL_FUNC void index_add(
|
||||||
void index_add(
|
constant size_t &dst_size,
|
||||||
device I *ids [[buffer(0)]],
|
constant size_t &left_size,
|
||||||
device T *inp [[buffer(1)]],
|
constant size_t &src_dim_size,
|
||||||
device T *out [[buffer(2)]],
|
constant size_t &right_size,
|
||||||
|
constant size_t &dst_dim_size,
|
||||||
constant uint &ids_dim_size,
|
constant size_t &ids_dim_size,
|
||||||
constant uint &left_size,
|
const device TYPENAME *input,
|
||||||
constant uint &dst_dim_size,
|
const device INDEX_TYPENAME *input_ids,
|
||||||
constant uint &right_size,
|
device TYPENAME *output,
|
||||||
|
uint tid [[ thread_position_in_grid ]]
|
||||||
uint gid [[ thread_position_in_grid ]] \
|
|
||||||
) {
|
) {
|
||||||
|
if (tid >= dst_size) {
|
||||||
if (gid >= left_size * right_size) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
const size_t right_rank_i = tid % right_size;
|
||||||
const uint i = gid;
|
const size_t left_rank_i = tid / right_size;
|
||||||
const uint pre = i / right_size;
|
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
||||||
const uint post = i % right_size;
|
const INDEX_TYPENAME idx = input_ids[j];
|
||||||
|
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||||
for (uint j = 0; j < ids_dim_size; j++) {
|
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||||
const uint idx = ids[j];
|
output[dst_i] += input[src_i];
|
||||||
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
|
|
||||||
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
|
||||||
out[dst_i] += inp[src_i];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
kernel void FN_NAME( \
|
kernel void NAME( \
|
||||||
device INDEX_TYPENAME *ids [[buffer(0)]], \
|
constant size_t &dst_size, \
|
||||||
device TYPENAME *inp [[buffer(1)]], \
|
constant size_t &left_size, \
|
||||||
device TYPENAME *out [[buffer(2)]], \
|
constant size_t &src_dim_size, \
|
||||||
constant uint &ids_dim_size, \
|
constant size_t &right_size, \
|
||||||
constant uint &left_size, \
|
constant size_t &dst_dim_size, \
|
||||||
constant uint &dst_dim_size, \
|
constant size_t &ids_dim_size, \
|
||||||
constant uint &right_size, \
|
const device TYPENAME *input, \
|
||||||
uint gid [[ thread_position_in_grid ]] \
|
const device INDEX_TYPENAME *input_ids, \
|
||||||
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
|
device TYPENAME *output, \
|
||||||
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
|
) { \
|
||||||
|
index_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, ids_dim_size, input, input_ids, output, tid); \
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
INDEX_OP(is_u32_f32, uint, float)
|
INDEX_OP(is_u32_f32, uint, float)
|
||||||
@ -175,25 +174,25 @@ SCATTER_ADD_OP(sa_u32_f16, uint, half)
|
|||||||
|
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
IA_OP(bfloat, int64_t, ia_i64_bf16)
|
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
|
||||||
IA_OP(bfloat, uint32_t, ia_u32_bf16)
|
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
|
||||||
IA_OP(bfloat, uint8_t, ia_u8_bf16)
|
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
IA_OP(half, uint32_t, ia_u32_f16)
|
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
|
||||||
IA_OP(half, uint8_t, ia_u8_f16)
|
INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
|
||||||
|
|
||||||
IA_OP(float, int64_t, ia_i64_f32)
|
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
|
||||||
IA_OP(uint8_t, int64_t, ia_i64_u8)
|
INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
|
||||||
IA_OP(int64_t, int64_t, ia_i64_i64)
|
INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
|
||||||
IA_OP(uint32_t, int64_t, ia_i64_u32)
|
INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
|
||||||
|
|
||||||
IA_OP(float, uint32_t, ia_u32_f32)
|
INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
|
||||||
IA_OP(uint8_t, uint32_t, ia_u32_u8)
|
INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
|
||||||
IA_OP(int64_t, uint32_t, ia_u32_i64)
|
INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
|
||||||
IA_OP(uint32_t, uint32_t, ia_u32_u32)
|
INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
|
||||||
|
|
||||||
IA_OP(float, uint8_t, ia_u8_f32)
|
INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
|
||||||
IA_OP(uint8_t, uint8_t, ia_u8_u8)
|
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
|
||||||
IA_OP(uint32_t, uint8_t, ia_u8_u32)
|
INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
|
||||||
IA_OP(int64_t, uint8_t, ia_u8_i64)
|
INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
|
||||||
|
@ -1114,6 +1114,60 @@ pub fn call_scatter_add(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn call_index_add(
|
||||||
|
device: &Device,
|
||||||
|
command_buffer: &CommandBufferRef,
|
||||||
|
kernels: &Kernels,
|
||||||
|
name: &'static str,
|
||||||
|
src_shape: &[usize],
|
||||||
|
dst_shape: &[usize],
|
||||||
|
ids_shape: &[usize],
|
||||||
|
dim: usize,
|
||||||
|
input: &Buffer,
|
||||||
|
input_offset: usize,
|
||||||
|
ids: &Buffer,
|
||||||
|
ids_offset: usize,
|
||||||
|
output: &Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let left_size: usize = src_shape[..dim].iter().product();
|
||||||
|
let right_size: usize = src_shape[dim + 1..].iter().product();
|
||||||
|
let src_dim_size = src_shape[dim];
|
||||||
|
let dst_el = left_size * right_size;
|
||||||
|
let dst_dim_size = dst_shape[dim];
|
||||||
|
let ids_dim_size = ids_shape[0];
|
||||||
|
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||||
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
|
set_params!(
|
||||||
|
encoder,
|
||||||
|
(
|
||||||
|
dst_el,
|
||||||
|
left_size,
|
||||||
|
src_dim_size,
|
||||||
|
right_size,
|
||||||
|
dst_dim_size,
|
||||||
|
ids_dim_size,
|
||||||
|
(input, input_offset),
|
||||||
|
(ids, ids_offset),
|
||||||
|
output
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
|
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
|
encoder.end_encoding();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
pub enum Value {
|
pub enum Value {
|
||||||
USize(usize),
|
USize(usize),
|
||||||
|
Reference in New Issue
Block a user