From 289c57d600e938027716cc8a271336a357f0551c Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 28 Dec 2023 17:31:07 +0100 Subject: [PATCH] Removing metal fences. Increases performance substantially on m1 pro. --- candle-core/src/metal_backend.rs | 2 +- candle-metal-kernels/src/lib.rs | 66 +++++-------------------------- candle-metal-kernels/src/tests.rs | 2 +- 3 files changed, 12 insertions(+), 58 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 6d8afab1..ca2b464d 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -1367,7 +1367,7 @@ impl BackendDevice for MetalDevice { let command_buffer = Arc::new(RwLock::new(command_buffer)); let command_buffer_index = Arc::new(RwLock::new(0)); let fence = device.new_fence(); - let kernels = Arc::new(Kernels::new(fence.clone())); + let kernels = Arc::new(Kernels::new()); let buffers = Arc::new(RwLock::new(HashMap::new())); let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") { Ok(val) => val.parse()?, diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index dd97a86d..3119c1e8 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -204,17 +204,15 @@ type Pipelines = HashMap<(&'static str, Option), ComputePipeline pub struct Kernels { libraries: RwLock, pipelines: RwLock, - fence: metal::Fence, } impl Kernels { - pub fn new(fence: metal::Fence) -> Self { + pub fn new() -> Self { let libraries = RwLock::new(Libraries::new()); let pipelines = RwLock::new(Pipelines::new()); Self { libraries, pipelines, - fence, } } @@ -334,7 +332,6 @@ pub fn call_unary_contiguous( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, input, output)); @@ -343,7 +340,6 @@ pub fn call_unary_contiguous( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -365,7 +361,6 @@ pub fn call_unary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -387,7 +382,6 @@ pub fn call_unary_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -406,7 +400,6 @@ pub fn call_binary_contiguous( let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, left, right, output)); @@ -417,7 +410,6 @@ pub fn call_binary_contiguous( encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -442,7 +434,6 @@ pub fn call_binary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); let width: usize = shape.iter().product(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -467,7 +458,6 @@ pub fn call_binary_strided( encoder.use_resource(right_input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -486,7 +476,6 @@ pub fn call_cast_contiguous( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, (input, input_offset), output)); @@ -495,7 +484,6 @@ pub fn call_cast_contiguous( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -515,7 +503,6 @@ pub fn call_cast_strided( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -537,7 +524,6 @@ pub fn call_cast_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -557,7 +543,6 @@ pub fn call_reduce_contiguous( let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -586,7 +571,6 @@ pub fn call_reduce_contiguous( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -608,7 +592,6 @@ pub fn call_reduce_strided( let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -644,7 +627,6 @@ pub fn call_reduce_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -663,7 +645,6 @@ pub fn call_last_softmax( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -694,7 +675,6 @@ pub fn call_last_softmax( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -714,7 +694,6 @@ pub fn call_affine( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, add, input, output)); @@ -723,7 +702,6 @@ pub fn call_affine( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -746,7 +724,6 @@ pub fn call_affine_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -767,8 +744,8 @@ pub fn call_affine_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); + Ok(()) } @@ -786,7 +763,6 @@ pub fn call_powf( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); @@ -795,8 +771,8 @@ pub fn call_powf( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); + Ok(()) } @@ -817,7 +793,6 @@ pub fn call_powf_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -837,7 +812,6 @@ pub fn call_powf_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -856,7 +830,6 @@ pub fn call_elu( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); @@ -865,7 +838,6 @@ pub fn call_elu( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -887,7 +859,6 @@ pub fn call_elu_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -907,7 +878,6 @@ pub fn call_elu_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -929,7 +899,6 @@ pub fn call_where_cond_strided( let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); @@ -958,7 +927,6 @@ pub fn call_where_cond_strided( encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -984,8 +952,6 @@ pub fn call_index_select( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); - - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1008,8 +974,8 @@ pub fn call_index_select( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); + Ok(()) } @@ -1036,8 +1002,6 @@ pub fn call_gather( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); - - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1060,8 +1024,8 @@ pub fn call_gather( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); + Ok(()) } @@ -1088,8 +1052,6 @@ pub fn call_scatter_add( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); - - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1112,8 +1074,8 @@ pub fn call_scatter_add( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); + Ok(()) } @@ -1141,8 +1103,6 @@ pub fn call_index_add( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); - - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1166,8 +1126,8 @@ pub fn call_index_add( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); + Ok(()) } @@ -1370,7 +1330,6 @@ pub fn call_gemm( let block_bytes = block_elements * bytes; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); @@ -1415,7 +1374,6 @@ pub fn call_gemm( encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_size, group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) @@ -1440,8 +1398,8 @@ pub fn call_im2col1d_strided( let encoder = command_buffer.new_compute_command_encoder(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); + set_params!( encoder, ( @@ -1460,7 +1418,6 @@ pub fn call_im2col1d_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) @@ -1490,8 +1447,8 @@ pub fn call_im2col_strided( let encoder = command_buffer.new_compute_command_encoder(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); + set_params!( encoder, ( @@ -1512,9 +1469,7 @@ pub fn call_im2col_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); - Ok(()) } @@ -1538,8 +1493,8 @@ pub fn call_upsample_nearest_2d( let scale_h = shape[3] as f32 / out_h as f32; let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); + set_params!( encoder, ( @@ -1556,7 +1511,6 @@ pub fn call_upsample_nearest_2d( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index c955abca..312032b4 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -38,7 +38,7 @@ fn approx_bf16(v: Vec, digits: i32) -> Vec { fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { let device = device(); let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v);