From 3aefc709c776b6989674a9d1867d362874cc5c44 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 5 Jan 2024 21:57:07 +0100 Subject: [PATCH] Cleanup the fence. --- candle-core/src/metal_backend.rs | 19 +---------- candle-core/src/quantized/metal.rs | 2 -- candle-metal-kernels/src/lib.rs | 52 ------------------------------ candle-metal-kernels/src/tests.rs | 33 +++++++------------ 4 files changed, 12 insertions(+), 94 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index ac37a7ce..dc790ac9 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -84,13 +84,8 @@ pub struct MetalDevice { command_buffer_index: Arc>, /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc) compute_per_buffer: usize, - /// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the - /// execution order to be linear. - /// It could be relaxed in some circumstances, by managing ourselves the dependencies in the - /// compute graph. - // fence: metal::Fence, /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them. - /// Heavily used by [`candle_metal_kernels`], both fences need to match + /// Heavily used by [`candle_metal_kernels`] kernels: Arc, /// Simple allocator struct. /// The buffers are stored in size buckets since ML tends to use similar shapes over and over. @@ -131,10 +126,6 @@ impl MetalDevice { &self.device } - // pub(crate) fn fence(&self) -> &metal::Fence { - // &self.fence - // } - pub fn command_queue(&self) -> &CommandQueue { &self.command_queue } @@ -225,10 +216,8 @@ impl MetalDevice { let command_buffer = self.command_buffer()?; command_buffer.set_label("with_data"); let blit = command_buffer.new_blit_command_encoder(); - // blit.wait_for_fence(&self.fence); blit.set_label("with_data_blit"); blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); - // blit.update_fence(&self.fence); blit.end_encoding(); // This is necessary, for mmaped safetensors @@ -251,7 +240,6 @@ impl MetalDevice { let command_buffer = self.command_buffer()?; command_buffer.set_label("zeros"); let blit = command_buffer.new_blit_command_encoder(); - // blit.wait_for_fence(&self.fence); blit.fill_buffer( &buffer, metal::NSRange { @@ -260,7 +248,6 @@ impl MetalDevice { }, 0, ); - // blit.update_fence(&self.fence); blit.end_encoding(); Ok(buffer) } @@ -1543,9 +1530,7 @@ impl MetalStorage { command_buffer.set_label("to_cpu"); let blit = command_buffer.new_blit_command_encoder(); blit.set_label("blit_to_cpu"); - // blit.wait_for_fence(&self.device.fence); blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); - // blit.update_fence(&self.device.fence); blit.end_encoding(); } self.device.wait_until_completed()?; @@ -1563,7 +1548,6 @@ impl BackendDevice for MetalDevice { command_buffer.enqueue(); let command_buffer = Arc::new(RwLock::new(command_buffer)); let command_buffer_index = Arc::new(RwLock::new(0)); - // let fence = device.new_fence(); let kernels = Arc::new(Kernels::new()); let buffers = Arc::new(RwLock::new(HashMap::new())); let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") { @@ -1572,7 +1556,6 @@ impl BackendDevice for MetalDevice { }; Ok(Self { device, - // fence, command_queue, command_buffer, command_buffer_index, diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index f435abb0..fe57ce14 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -32,9 +32,7 @@ impl QMetalStorage { command_buffer.set_label("to_cpu"); let blit = command_buffer.new_blit_command_encoder(); blit.set_label("blit_to_cpu"); - // blit.wait_for_fence(&self.device.fence()); blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); - // blit.update_fence(&self.device.fence()); blit.end_encoding(); self.device.wait_until_completed()?; let mut out = vec![0.0; elem_count]; diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 626e1c3d..201af97e 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -219,7 +219,6 @@ type Pipelines = HashMap<(&'static str, Option), ComputePipeline pub struct Kernels { libraries: RwLock, pipelines: RwLock, - // fence: metal::Fence, } impl Kernels { @@ -229,7 +228,6 @@ impl Kernels { Self { libraries, pipelines, - // fence, } } @@ -350,7 +348,6 @@ pub fn call_unary_contiguous( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, input, output)); @@ -359,7 +356,6 @@ pub fn call_unary_contiguous( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -381,7 +377,6 @@ pub fn call_unary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -403,7 +398,6 @@ pub fn call_unary_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -422,7 +416,6 @@ pub fn call_binary_contiguous( let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, left, right, output)); @@ -433,7 +426,6 @@ pub fn call_binary_contiguous( encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -458,7 +450,6 @@ pub fn call_binary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); let width: usize = shape.iter().product(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -483,7 +474,6 @@ pub fn call_binary_strided( encoder.use_resource(right_input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -502,7 +492,6 @@ pub fn call_cast_contiguous( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, (input, input_offset), output)); @@ -511,7 +500,6 @@ pub fn call_cast_contiguous( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -531,7 +519,6 @@ pub fn call_cast_strided( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -553,7 +540,6 @@ pub fn call_cast_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -573,7 +559,6 @@ pub fn call_reduce_contiguous( let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -602,7 +587,6 @@ pub fn call_reduce_contiguous( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -624,7 +608,6 @@ pub fn call_reduce_strided( let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -660,7 +643,6 @@ pub fn call_reduce_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -679,7 +661,6 @@ pub fn call_last_softmax( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -710,7 +691,6 @@ pub fn call_last_softmax( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -730,7 +710,6 @@ pub fn call_affine( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, add, input, output)); @@ -739,7 +718,6 @@ pub fn call_affine( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -762,7 +740,6 @@ pub fn call_affine_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -783,7 +760,6 @@ pub fn call_affine_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -802,7 +778,6 @@ pub fn call_powf( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); @@ -811,7 +786,6 @@ pub fn call_powf( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -833,7 +807,6 @@ pub fn call_powf_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -853,7 +826,6 @@ pub fn call_powf_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -872,7 +844,6 @@ pub fn call_elu( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); @@ -881,7 +852,6 @@ pub fn call_elu( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -903,7 +873,6 @@ pub fn call_elu_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -923,7 +892,6 @@ pub fn call_elu_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -945,7 +913,6 @@ pub fn call_where_cond_strided( let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); @@ -974,7 +941,6 @@ pub fn call_where_cond_strided( encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1001,7 +967,6 @@ pub fn call_index_select( let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1024,7 +989,6 @@ pub fn call_index_select( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1053,7 +1017,6 @@ pub fn call_gather( let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1076,7 +1039,6 @@ pub fn call_gather( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1105,7 +1067,6 @@ pub fn call_scatter_add( let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1128,7 +1089,6 @@ pub fn call_scatter_add( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1158,7 +1118,6 @@ pub fn call_index_add( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1182,7 +1141,6 @@ pub fn call_index_add( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1386,7 +1344,6 @@ pub fn call_gemm( let block_bytes = block_elements * bytes; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); @@ -1430,7 +1387,6 @@ pub fn call_gemm( encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_size, group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) @@ -1455,7 +1411,6 @@ pub fn call_im2col1d_strided( let encoder = command_buffer.new_compute_command_encoder(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -1475,7 +1430,6 @@ pub fn call_im2col1d_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) @@ -1505,7 +1459,6 @@ pub fn call_im2col_strided( let encoder = command_buffer.new_compute_command_encoder(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -1527,7 +1480,6 @@ pub fn call_im2col_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) @@ -1553,7 +1505,6 @@ pub fn call_upsample_nearest_2d( let scale_h = shape[3] as f32 / out_h as f32; let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -1571,7 +1522,6 @@ pub fn call_upsample_nearest_2d( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) @@ -1710,7 +1660,6 @@ pub fn call_quantized_matmul_t( let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; let encoder = command_buffer.new_compute_command_encoder(); - //encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1743,7 +1692,6 @@ pub fn call_quantized_matmul_t( encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup); - //encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 87f8ac45..787a7d45 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -37,8 +37,7 @@ fn approx_bf16(v: Vec, digits: i32) -> Vec { fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -60,8 +59,7 @@ fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; @@ -96,8 +94,7 @@ fn run_strided( let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); let output = new_buffer(&device, v); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); call_unary_strided( &device, command_buffer, @@ -278,8 +275,7 @@ fn binary_ops_bf16() { fn cast(v: &[T], name: &'static str) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -409,8 +405,7 @@ fn it_cast_f16_bf16() { fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); @@ -445,8 +440,7 @@ fn run_affine_strided( add: f64, ) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); @@ -595,8 +589,7 @@ fn run_index_select( let dst_el = ids.len() * left_size * right_size; let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); call_index_select( &device, &command_buffer, @@ -631,8 +624,7 @@ fn cos_f16() { fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -662,8 +654,7 @@ fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec(v: &[T], last_dim: usize, name: &'static str) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -782,8 +773,7 @@ fn run_where_cond( name: &'static str, ) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; @@ -859,8 +849,7 @@ fn run_gemm( rhs_offset: usize, ) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged;