From 677495f9b8eb91e132fa1c2a3b3cb4e75afafc88 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 6 Nov 2023 17:04:47 +0100 Subject: [PATCH] Working but failing tests because of threadgroup. --- candle-metal-kernels/src/affine.metal | 2 +- candle-metal-kernels/src/lib.rs | 37 +++++++++---- candle-metal-kernels/src/unary.metal | 78 ++++++++++++++++++++++----- 3 files changed, 92 insertions(+), 25 deletions(-) diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 4111e799..7bd98adc 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -59,4 +59,4 @@ kernel void affine( out[strided_i] = x * mul + add; } } -} \ No newline at end of file +} diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index db25e6f3..e1bdda63 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -137,6 +137,10 @@ mod tests { }; use std::mem; + fn device() -> Device { + Device::system_default().unwrap() + } + fn approx(v: Vec, digits: i32) -> Vec { let b = 10f32.powi(digits); v.iter().map(|t| f32::round(t * b) / b).collect() @@ -148,8 +152,9 @@ mod tests { } fn run_cos(v: &[T], name: &str) -> Vec { + let device = device(); let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache; - let device = Device::system_default().unwrap(); + let option = metal::MTLResourceOptions::StorageModeManaged; let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = device.new_buffer_with_data( @@ -165,18 +170,27 @@ mod tests { let pipeline_state_descriptor = ComputePipelineDescriptor::new(); pipeline_state_descriptor.set_compute_function(Some(&func)); - let pipeline_state = device + let pipeline = device .new_compute_pipeline_state_with_function( pipeline_state_descriptor.compute_function().unwrap(), ) .unwrap(); - let encoder = command_buffer.new_compute_command_encoder(); - encoder.set_compute_pipeline_state(&pipeline_state); - encoder.set_buffer(0, Some(&input), 0); - encoder.set_buffer(1, Some(&output), 0); + let dim: u32 = v.len() as u32; + // let num_dims: u32 = 1; + // let info = [v.len() as u32, 1]; - let width = 16; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + encoder.set_bytes(0, 4, void_ptr(&dim)); + // encoder.set_bytes(1, 4, void_ptr(&num_dims)); + // encoder.set_bytes(2, 4, void_ptr(&info)); + + encoder.set_buffer(1, Some(&input), 0); + encoder.set_buffer(2, Some(&output), 0); + + let width = v.len() as NSUInteger; let thread_group_count = MTLSize { width, @@ -185,7 +199,7 @@ mod tests { }; let thread_group_size = MTLSize { - width: (v.len() as u64 + width) / width, + width: pipeline.max_total_threads_per_threadgroup(), height: 1, depth: 1, }; @@ -208,7 +222,7 @@ mod tests { #[test] fn affine() { - let device = Device::system_default().expect("no device found"); + let device = device(); let options = CompileOptions::new(); let library = device.new_library_with_source(AFFINE, &options).unwrap(); @@ -225,7 +239,8 @@ mod tests { let pipeline = device .new_compute_pipeline_state_with_function(&function) .unwrap(); - let options = MTLResourceOptions::StorageModeShared; + // let options = MTLResourceOptions::StorageModeShared; + let options = metal::MTLResourceOptions::StorageModeManaged; let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); @@ -291,7 +306,7 @@ mod tests { let pipeline = device .new_compute_pipeline_state_with_function(&function) .unwrap(); - let options = MTLResourceOptions::StorageModeShared; + let options = metal::MTLResourceOptions::StorageModeManaged; let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index a2635b68..b8056909 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -1,22 +1,74 @@ #include +# +METAL_FUNC bool is_contiguous( + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + size_t acc = 1; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + if (acc != strides[dim_idx]) { + return false; + } + acc *= dims[dim_idx]; + } + return true; +} + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + using namespace metal; -template -kernel void unary_cos(device const T *input, device T *output, uint index [[thread_position_in_grid]]) -{ - output[index] = cos(input[index]); +#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { \ + const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ + output[i] = FN(input[i]); \ +}\ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *info, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { \ + constant size_t *dims = info; \ + constant size_t *strides = info + num_dims; \ + const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ + const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \ + for (size_t i = start; i < stop; i++) { \ + output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \ + output[i] = 1; \ + } \ } -#define UNARY(FN, TYPENAME, FN_NAME) \ -kernel void FN_NAME(device const TYPENAME *input, device TYPENAME *output, uint index [[thread_position_in_grid]]) \ -{ \ - output[index] = FN(input[index]);\ -} - -UNARY(cos, float, cos_float); -UNARY(cos, half, cos_half); +UNARY(cos, float, cos_float, cos_float_strided); +UNARY(cos, half, cos_half, cos_half_strided); #if __METAL_VERSION__ >= 310 -UNARY(cos, half, cos_half); +UNARY(cos, bfloat, cos_bfloat, cos_bfloat_strided); #endif