From 6a3ca7da0cfb06e80d5c75ee98a1291843092e06 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 18 Dec 2023 10:32:22 +0100 Subject: [PATCH] Scatter add. --- candle-core/src/metal_backend.rs | 60 ++++++++++++++++++++----- candle-metal-kernels/src/indexing.metal | 46 ++++++++++++++++--- candle-metal-kernels/src/lib.rs | 58 +++++++++++++++++++++++- 3 files changed, 147 insertions(+), 17 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 227bcfb0..b26477fc 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -45,6 +45,12 @@ pub enum MetalError { }, #[error("{0:?}")] LockError(LockError), + #[error("{msg}, expected: {expected:?}, got: {got:?}")] + UnexpectedDType { + msg: &'static str, + expected: DType, + got: DType, + }, } impl From for MetalError { @@ -827,12 +833,10 @@ impl BackendStorage for MetalStorage { } fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result { - let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() { + let (ids_o1, _) = match ids_l.contiguous_offsets() { Some(o12) => o12, None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?, }; - let left_size: usize = src_l.dims()[..dim].iter().product(); - let right_size: usize = src_l.dims()[dim + 1..].iter().product(); let ids_el = ids_l.dims()[dim]; let dst_el = ids_l.shape().elem_count(); let dtype = self.dtype; @@ -853,7 +857,9 @@ impl BackendStorage for MetalStorage { ids_el, dim, &self.buffer, + src_l.start_offset() * dtype.size_in_bytes(), &ids.buffer, + ids_o1 * ids.dtype.size_in_bytes(), &buffer, ) .map_err(MetalError::from)?; @@ -862,14 +868,48 @@ impl BackendStorage for MetalStorage { fn scatter_add( &self, - _: &Layout, - _: &Self, - _: &Layout, - _: &Self, - _: &Layout, - _: usize, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, ) -> Result { - crate::bail!("scatter_add metal") + let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; + self.copy_strided_src(&mut acc, 0, l)?; + let (ids_offset, _) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + }; + let src_offset = match src_l.contiguous_offsets() { + Some((o1, _)) => o1, + None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + }; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "sa_u32_f32", + _ => Err(MetalError::UnexpectedDType { + msg: "scatter-add ids should be u8/u32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let command_buffer = self.device.command_buffer()?; + candle_metal_kernels::call_scatter_add( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + l.dims(), + dim, + &src.buffer, + src_offset * src.dtype.size_in_bytes(), + &ids.buffer, + ids_offset * ids.dtype.size_in_bytes(), + &acc.buffer, + ) + .map_err(MetalError::from)?; + Ok(acc) } fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 96adb4c4..72a3a348 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -63,11 +63,6 @@ METAL_FUNC void gather( const INDEX_TYPENAME input_i = input_ids[tid]; const size_t right_rank_i = tid % right_size; const size_t left_rank_i = tid / right_size / ids_size; - /* - // Force prevent out of bounds indexing - // since there doesn't seem to be a good way to force crash - // No need to check for zero we're only allowing unsized. - */ const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; output[tid] = input[src_i]; } @@ -87,6 +82,45 @@ kernel void NAME( \ gather(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \ } +template +METAL_FUNC void scatter_add( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &dst_dim_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size; + for (unsigned int j = 0; j < src_dim_size; ++j) { + const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; + const INDEX_TYPENAME idx = input_ids[src_i]; + const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; + output[dst_i] += input[src_i]; + } +} + +# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &dst_dim_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + scatter_add(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \ +} template @@ -136,6 +170,8 @@ INDEX_OP(is_u32_f32, uint, float) INDEX_OP(is_u32_f16, uint, half) GATHER_OP(gather_u32_f32, uint, float) GATHER_OP(gather_u32_f16, uint, half) +SCATTER_ADD_OP(sa_u32_f32, uint, float) +SCATTER_ADD_OP(sa_u32_f16, uint, half) #if __METAL_VERSION__ >= 310 diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 45929aa3..ddc04d05 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1020,7 +1020,9 @@ pub fn call_gather( ids_size: usize, dim: usize, input: &Buffer, + input_offset: usize, ids: &Buffer, + ids_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); @@ -1043,8 +1045,60 @@ pub fn call_gather( src_dim_size, right_size, ids_size, - input, - ids, + (input, input_offset), + (ids, ids_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_scatter_add( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + src_shape: &[usize], + dst_shape: &[usize], + dim: usize, + input: &Buffer, + input_offset: usize, + ids: &Buffer, + ids_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = src_shape[..dim].iter().product(); + let right_size: usize = src_shape[dim + 1..].iter().product(); + let src_dim_size = src_shape[dim]; + let dst_el = left_size * right_size; + let dst_dim_size = dst_shape[dim]; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + dst_dim_size, + (input, input_offset), + (ids, ids_offset), output ) );