From 586b6f6fff01f02cf5275f9ede47a0fe10206210 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sun, 17 Dec 2023 23:34:12 +0100 Subject: [PATCH] Adding gather op. --- candle-core/src/metal_backend.rs | 34 +++++++++- candle-metal-kernels/src/indexing.metal | 90 ++++++++++++++++++++----- candle-metal-kernels/src/lib.rs | 50 ++++++++++++++ 3 files changed, 157 insertions(+), 17 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 6f82b0cc..227bcfb0 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -826,8 +826,38 @@ impl BackendStorage for MetalStorage { crate::bail!("upsample_nearest2d metal") } - fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result { - crate::bail!("gather metal") + fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result { + let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?, + }; + let left_size: usize = src_l.dims()[..dim].iter().product(); + let right_size: usize = src_l.dims()[dim + 1..].iter().product(); + let ids_el = ids_l.dims()[dim]; + let dst_el = ids_l.shape().elem_count(); + let dtype = self.dtype; + let device = self.device(); + let buffer = device.new_buffer(dst_el, dtype, "index_select")?; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "gather_u32_f32", + (DType::U32, DType::F16) => "gather_u32_f16", + (left, right) => crate::bail!("gather metal {left:?} {right:?} not implemented"), + }; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_gather( + &device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + ids_el, + dim, + &self.buffer, + &ids.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, device.clone(), dtype)) } fn scatter_add( diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 312b27c7..96adb4c4 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -1,6 +1,34 @@ #include using namespace metal; +template +METAL_FUNC void index( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &ids_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t id_i = (tid / right_size) % ids_size; + const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size / ids_size; + /* + // Force prevent out of bounds indexing + // since there doesn't seem to be a good way to force crash + // No need to check for zero we're only allowing unsized. + */ + const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; + output[tid] = input[src_i]; +} + # define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \ kernel void NAME( \ constant size_t &dst_size, \ @@ -11,22 +39,52 @@ kernel void NAME( \ const device TYPENAME *input, \ const device INDEX_TYPENAME *input_ids, \ device TYPENAME *output, \ - uint gid [[ thread_position_in_grid ]] \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (gid >= dst_size) { \ - return; \ - } \ - const size_t id_i = (gid / right_size) % ids_size; \ - const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ - const size_t right_rank_i = gid % right_size; \ - const size_t left_rank_i = gid / right_size / ids_size; \ - /* \ - // Force prevent out of bounds indexing \ - // since there doesn't seem to be a good way to force crash \ - // No need to check for zero we're only allowing unsized. \ - */ \ - const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \ - output[gid] = input[src_i]; \ + index(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \ +} + + +template +METAL_FUNC void gather( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &ids_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const INDEX_TYPENAME input_i = input_ids[tid]; + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size / ids_size; + /* + // Force prevent out of bounds indexing + // since there doesn't seem to be a good way to force crash + // No need to check for zero we're only allowing unsized. + */ + const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; + output[tid] = input[src_i]; +} + +# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &ids_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + gather(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \ } @@ -76,6 +134,8 @@ kernel void FN_NAME( \ INDEX_OP(is_u32_f32, uint, float) INDEX_OP(is_u32_f16, uint, half) +GATHER_OP(gather_u32_f32, uint, float) +GATHER_OP(gather_u32_f16, uint, half) #if __METAL_VERSION__ >= 310 diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 7485ba72..45929aa3 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1010,6 +1010,56 @@ pub fn call_index_select( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_gather( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + ids_size: usize, + dim: usize, + input: &Buffer, + ids: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let src_dim_size = shape[dim]; + let dst_el = ids_size * left_size * right_size; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + ids_size, + input, + ids, + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + #[derive(Debug, PartialEq)] pub enum Value { USize(usize),