Adding gather op.

This commit is contained in:
Nicolas Patry
2023-12-17 23:34:12 +01:00
parent e4b0cc59f5
commit 586b6f6fff
3 changed files with 157 additions and 17 deletions

View File

@ -826,8 +826,38 @@ impl BackendStorage for MetalStorage {
crate::bail!("upsample_nearest2d metal")
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
crate::bail!("gather metal")
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
};
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
let ids_el = ids_l.dims()[dim];
let dst_el = ids_l.shape().elem_count();
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
(left, right) => crate::bail!("gather metal {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_gather(
&device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
ids_el,
dim,
&self.buffer,
&ids.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, device.clone(), dtype))
}
fn scatter_add(

View File

@ -1,6 +1,34 @@
#include <metal_stdlib>
using namespace metal;
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t id_i = (tid / right_size) % ids_size;
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
/*
// Force prevent out of bounds indexing
// since there doesn't seem to be a good way to force crash
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
output[tid] = input[src_i];
}
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
@ -11,22 +39,52 @@ kernel void NAME( \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint gid [[ thread_position_in_grid ]] \
uint tid [[ thread_position_in_grid ]] \
) { \
if (gid >= dst_size) { \
return; \
} \
const size_t id_i = (gid / right_size) % ids_size; \
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
const size_t right_rank_i = gid % right_size; \
const size_t left_rank_i = gid / right_size / ids_size; \
/* \
// Force prevent out of bounds indexing \
// since there doesn't seem to be a good way to force crash \
// No need to check for zero we're only allowing unsized. \
*/ \
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
output[gid] = input[src_i]; \
index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void gather(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const INDEX_TYPENAME input_i = input_ids[tid];
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
/*
// Force prevent out of bounds indexing
// since there doesn't seem to be a good way to force crash
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
output[tid] = input[src_i];
}
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}
@ -76,6 +134,8 @@ kernel void FN_NAME( \
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310

View File

@ -1010,6 +1010,56 @@ pub fn call_index_select(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_gather(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
ids_size: usize,
dim: usize,
input: &Buffer,
ids: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let src_dim_size = shape[dim];
let dst_el = ids_size * left_size * right_size;
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
left_size,
src_dim_size,
right_size,
ids_size,
input,
ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[derive(Debug, PartialEq)]
pub enum Value {
USize(usize),