From d4d6850c782ffe073690cfe9d33c754bc189747b Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sat, 4 Nov 2023 08:46:08 +0100 Subject: [PATCH] Impl index_add via template for all types --- candle-metal-kernels/src/indexing.metal | 93 ++++++++++++++++--------- candle-metal-kernels/src/lib.rs | 62 +++++++---------- candle-metal-kernels/src/utils.metal | 18 +++++ 3 files changed, 103 insertions(+), 70 deletions(-) create mode 100644 candle-metal-kernels/src/utils.metal diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 0a3826d8..2c80e556 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -1,44 +1,71 @@ #include -#include -#define METAL_FUNC inline __attribute__((__always_inline__)) using namespace metal; -struct fault_counter { - uint counter; - uint tolerance; +template +void index_add( + device I *ids [[buffer(0)]], + device T *inp [[buffer(1)]], + device T *out [[buffer(2)]], - fault_counter(uint tolerance) { - this->counter = 0; - this->tolerance = tolerance; - } + constant uint &ids_dim_size, + constant uint &left_size, + constant uint &dst_dim_size, + constant uint &right_size, - bool quit() { - counter += 1; - return (counter > tolerance); - } -}; - -constant uint IDS_DIM_SIZE [[function_constant(0)]]; -constant uint SRC_DIM_SIZE [[function_constant(1)]]; // Not needed -constant uint DST_DIM_SIZE [[function_constant(2)]]; -constant uint LEFT_SIZE [[function_constant(3)]]; -constant uint RIGHT_SIZE [[function_constant(4)]]; -constant uint NUMEL = LEFT_SIZE * RIGHT_SIZE; - -kernel void index_add( - device uint *ids [[buffer(0)]], - device float *inp [[buffer(1)]], - device float *out [[buffer(2)]], + uint threadgroup_size [[threads_per_threadgroup]], + uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], uint thread_index [[thread_index_in_threadgroup]] ) { - const uint i = thread_index; - const uint pre = i / RIGHT_SIZE; - const uint post = i % RIGHT_SIZE; - for (uint j = 0; j < IDS_DIM_SIZE; ++j) { + const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size); + if (gid >= left_size * right_size) { + return; + } + + const uint i = gid; + const uint pre = i / right_size; + const uint post = i % right_size; + + for (uint j = 0; j < ids_dim_size; j++) { const uint idx = ids[j]; - const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post; - const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post; + const uint src_i = (pre * ids_dim_size + j) * right_size + post; + const uint dst_i = (pre * dst_dim_size + idx) * right_size + post; out[dst_i] += inp[src_i]; } -} \ No newline at end of file +} + +#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +kernel void FN_NAME( \ + device INDEX_TYPENAME *ids [[buffer(0)]], \ + device TYPENAME *inp [[buffer(1)]], \ + device TYPENAME *out [[buffer(2)]], \ + constant uint &ids_dim_size, \ + constant uint &left_size, \ + constant uint &dst_dim_size, \ + constant uint &right_size, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \ + +IA_OP(bfloat, int64_t, ia_i64_bf16) +IA_OP(bfloat, uint32_t, ia_u32_bf16) +IA_OP(bfloat, uint8_t, ia_u8_bf16) + +IA_OP(half, uint32_t, ia_u32_f16) +IA_OP(half, uint8_t, ia_u8_f16) + +IA_OP(float, int64_t, ia_i64_f32) +IA_OP(uint8_t, int64_t, ia_i64_u8) +IA_OP(int64_t, int64_t, ia_i64_i64) +IA_OP(uint32_t, int64_t, ia_i64_u32) + +IA_OP(float, uint32_t, ia_u32_f32) +IA_OP(uint8_t, uint32_t, ia_u32_u8) +IA_OP(int64_t, uint32_t, ia_u32_i64) +IA_OP(uint32_t, uint32_t, ia_u32_u32) + +IA_OP(float, uint8_t, ia_u8_f32) +IA_OP(uint8_t, uint8_t, ia_u8_u8) +IA_OP(uint32_t, uint8_t, ia_u8_u32) +IA_OP(int64_t, uint8_t, ia_u8_i64) \ No newline at end of file diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index d4c9cfc3..8625de3b 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -64,8 +64,8 @@ fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usiz mod tests { use super::*; use metal::{ - CompileOptions, ComputePipelineDescriptor, Device, FunctionConstantValues, MTLDataType, - MTLResourceOptions, MTLResourceUsage, MTLSize, NSUInteger, + CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLResourceUsage, + MTLSize, NSUInteger, }; use std::ffi::c_void; use std::mem; @@ -99,7 +99,7 @@ mod tests { let argument_encoder = func.new_argument_encoder(0); let arg_buffer = device.new_buffer( argument_encoder.encoded_length(), - MTLResourceOptions::empty(), + MTLResourceOptions::StorageModeShared, ); argument_encoder.set_argument_buffer(&arg_buffer, 0); argument_encoder.set_buffer(0, &input, 0); @@ -154,66 +154,54 @@ mod tests { let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; let right = [1.0f32; 15]; let index = [0u32, 4, 2]; - let ids_dim_size = index.len() as u32; - - // Are these reversed? - let src_dim_size: u32 = 9; let dst_dim_size: u32 = 15; let left_size: u32 = 3; let right_size: u32 = 3; - let fcv = FunctionConstantValues::new(); - fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0); - fcv.set_constant_value_at_index(void_ptr(&src_dim_size), MTLDataType::UInt, 1); - fcv.set_constant_value_at_index(void_ptr(&dst_dim_size), MTLDataType::UInt, 2); - fcv.set_constant_value_at_index(void_ptr(&left_size), MTLDataType::UInt, 3); - fcv.set_constant_value_at_index(void_ptr(&right_size), MTLDataType::UInt, 4); - - let function = library.get_function("index_add", Some(fcv)).unwrap(); + let function = library.get_function("ia_u32_f32", None).unwrap(); let pipeline = device .new_compute_pipeline_state_with_function(&function) .unwrap(); let options = MTLResourceOptions::StorageModeShared; - let ids_size = (index.len() * mem::size_of::()) as NSUInteger; - let input_size = (left.len() * mem::size_of::()) as NSUInteger; - let output_size = (right.len() * mem::size_of::()) as NSUInteger; - - let ids = device.new_buffer_with_data(void_ptr(&index), ids_size, options); - let inputs = device.new_buffer_with_data(void_ptr(&left), input_size, options); - let outputs = device.new_buffer_with_data(void_ptr(&right), output_size, options); - let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let encoder = command_buffer.new_compute_command_encoder(); + let ids_size = (index.len() * mem::size_of::()) as NSUInteger; + let input_size = (left.len() * mem::size_of::()) as NSUInteger; + let output_size = (right.len() * mem::size_of::()) as NSUInteger; + encoder.set_compute_pipeline_state(&pipeline); - let thread_group_memory_length = output_size; - encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger); + encoder.set_threadgroup_memory_length(0, output_size as NSUInteger); - encoder.use_resource(&ids, MTLResourceUsage::Read); - encoder.use_resource(&inputs, MTLResourceUsage::Read); - encoder.use_resource(&outputs, MTLResourceUsage::Write); + let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options); + let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options); + let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options); - encoder.set_buffer(0, Some(&ids), 0); - encoder.set_buffer(1, Some(&inputs), 0); - encoder.set_buffer(2, Some(&outputs), 0); - let width = 16; + encoder.set_buffer(0, Some(&index_buffer), 0); + encoder.set_buffer(1, Some(&inputs_buffer), 0); + encoder.set_buffer(2, Some(&outputs_buffer), 0); - let thread_group_count = MTLSize { - width: 1, + encoder.set_bytes(3, 4, void_ptr(&ids_dim_size)); + encoder.set_bytes(4, 4, void_ptr(&left_size)); + encoder.set_bytes(5, 4, void_ptr(&dst_dim_size)); + encoder.set_bytes(6, 4, void_ptr(&right_size)); + + let grid_size = MTLSize { + width: right.len() as NSUInteger, height: 1, depth: 1, }; let thread_group_size = MTLSize { - width, + width: pipeline.max_total_threads_per_threadgroup(), height: 1, depth: 1, }; - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.dispatch_threads(grid_size, thread_group_size); encoder.end_encoding(); command_buffer.commit(); command_buffer.wait_until_completed(); @@ -221,7 +209,7 @@ mod tests { let expected = vec![ 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, ]; - let result = outputs.read_to_vec::(right.len()); + let result = outputs_buffer.read_to_vec::(right.len()); assert_eq!(result, expected); } } diff --git a/candle-metal-kernels/src/utils.metal b/candle-metal-kernels/src/utils.metal new file mode 100644 index 00000000..f2db8087 --- /dev/null +++ b/candle-metal-kernels/src/utils.metal @@ -0,0 +1,18 @@ +#include + +using namespace metal; + +struct fault_counter { + uint counter; + uint tolerance; + + fault_counter(uint tolerance) { + this->counter = 0; + this->tolerance = tolerance; + } + + bool quit() { + counter += 1; + return (counter > tolerance); + } +}; \ No newline at end of file