Cleanup the fence.

Author: Nicolas Patry
Date:   2024-01-05 21:57:07 +01:00
Parent: c8c603ce96
Commit: 3aefc709c7

4 changed files with 12 additions and 94 deletions
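The cleanup is mechanical: the fence-based synchronization around every compute and blit encoder was already commented out in the parent commit, and this commit deletes those remnants along with the commented-out `fence` fields and the fence argument at the `Kernels::new` call sites. A minimal before/after sketch of the encoder pattern (the wrapper functions and imports are illustrative, not the exact candle source):

    use metal::{CommandBufferRef, ComputePipelineState, Fence};

    // Parent-commit shape: every encoder was bracketed by a shared fence,
    // which forced compute and blit work into a strictly linear order.
    fn encode_with_fence(
        command_buffer: &CommandBufferRef,
        pipeline: &ComputePipelineState,
        fence: &Fence,
    ) {
        let encoder = command_buffer.new_compute_command_encoder();
        encoder.wait_for_fence(fence);
        encoder.set_compute_pipeline_state(pipeline);
        // ... set_params! / use_resource / dispatch_thread_groups ...
        encoder.update_fence(fence);
        encoder.end_encoding();
    }

    // Shape after this commit: the commented-out fence calls are gone and only
    // the plain encoder lifecycle remains.
    fn encode_without_fence(command_buffer: &CommandBufferRef, pipeline: &ComputePipelineState) {
        let encoder = command_buffer.new_compute_command_encoder();
        encoder.set_compute_pipeline_state(pipeline);
        // ... set_params! / use_resource / dispatch_thread_groups ...
        encoder.end_encoding();
    }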

View File

@@ -84,13 +84,8 @@ pub struct MetalDevice {
     command_buffer_index: Arc<RwLock<usize>>,
     /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
     compute_per_buffer: usize,
-    /// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the
-    /// execution order to be linear.
-    /// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
-    /// compute graph.
-    // fence: metal::Fence,
     /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
-    /// Heavily used by [`candle_metal_kernels`], both fences need to match
+    /// Heavily used by [`candle_metal_kernels`]
     kernels: Arc<candle_metal_kernels::Kernels>,
     /// Simple allocator struct.
     /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
@@ -131,10 +126,6 @@ impl MetalDevice {
         &self.device
     }
 
-    // pub(crate) fn fence(&self) -> &metal::Fence {
-    //     &self.fence
-    // }
-
     pub fn command_queue(&self) -> &CommandQueue {
         &self.command_queue
     }
@@ -225,10 +216,8 @@ impl MetalDevice {
         let command_buffer = self.command_buffer()?;
         command_buffer.set_label("with_data");
         let blit = command_buffer.new_blit_command_encoder();
-        // blit.wait_for_fence(&self.fence);
         blit.set_label("with_data_blit");
         blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
-        // blit.update_fence(&self.fence);
         blit.end_encoding();
 
         // This is necessary, for mmaped safetensors
@@ -251,7 +240,6 @@ impl MetalDevice {
         let command_buffer = self.command_buffer()?;
         command_buffer.set_label("zeros");
         let blit = command_buffer.new_blit_command_encoder();
-        // blit.wait_for_fence(&self.fence);
         blit.fill_buffer(
             &buffer,
             metal::NSRange {
@@ -260,7 +248,6 @@ impl MetalDevice {
             },
             0,
         );
-        // blit.update_fence(&self.fence);
         blit.end_encoding();
         Ok(buffer)
     }
@@ -1543,9 +1530,7 @@ impl MetalStorage {
             command_buffer.set_label("to_cpu");
             let blit = command_buffer.new_blit_command_encoder();
             blit.set_label("blit_to_cpu");
-            // blit.wait_for_fence(&self.device.fence);
             blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
-            // blit.update_fence(&self.device.fence);
             blit.end_encoding();
         }
         self.device.wait_until_completed()?;
@@ -1563,7 +1548,6 @@ impl BackendDevice for MetalDevice {
         command_buffer.enqueue();
         let command_buffer = Arc::new(RwLock::new(command_buffer));
         let command_buffer_index = Arc::new(RwLock::new(0));
-        // let fence = device.new_fence();
        let kernels = Arc::new(Kernels::new());
         let buffers = Arc::new(RwLock::new(HashMap::new()));
         let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
@@ -1572,7 +1556,6 @@ impl BackendDevice for MetalDevice {
         };
         Ok(Self {
             device,
-            // fence,
             command_queue,
             command_buffer,
             command_buffer_index,
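The blit paths in this file get the same treatment; a condensed sketch of the copy pattern that remains (the wrapper function is illustrative, the calls mirror the `with_data` hunk above):

    use metal::{BufferRef, CommandBufferRef};

    // Copy `src` into `dst` with a blit encoder; after the cleanup nothing
    // waits on or updates a fence around the blit anymore.
    fn blit_copy(command_buffer: &CommandBufferRef, src: &BufferRef, dst: &BufferRef) {
        let blit = command_buffer.new_blit_command_encoder();
        blit.set_label("with_data_blit");
        blit.copy_from_buffer(src, 0, dst, 0, src.length());
        blit.end_encoding();
    }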

View File

@@ -32,9 +32,7 @@ impl QMetalStorage {
         command_buffer.set_label("to_cpu");
         let blit = command_buffer.new_blit_command_encoder();
         blit.set_label("blit_to_cpu");
-        // blit.wait_for_fence(&self.device.fence());
         blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
-        // blit.update_fence(&self.device.fence());
         blit.end_encoding();
         self.device.wait_until_completed()?;
         let mut out = vec![0.0; elem_count];

View File

@@ -219,7 +219,6 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
 pub struct Kernels {
     libraries: RwLock<Libraries>,
     pipelines: RwLock<Pipelines>,
-    // fence: metal::Fence,
 }
 
 impl Kernels {
@@ -229,7 +228,6 @@ impl Kernels {
         Self {
             libraries,
             pipelines,
-            // fence,
         }
     }
 
@@ -350,7 +348,6 @@ pub fn call_unary_contiguous(
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (length, input, output));
@@ -359,7 +356,6 @@ pub fn call_unary_contiguous(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -381,7 +377,6 @@ pub fn call_unary_strided(
 
     let num_dims: usize = shape.len();
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     let length: usize = shape.iter().product();
@@ -403,7 +398,6 @@ pub fn call_unary_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -422,7 +416,6 @@ pub fn call_binary_contiguous(
     let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (length, left, right, output));
@@ -433,7 +426,6 @@ pub fn call_binary_contiguous(
     encoder.use_resource(right, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -458,7 +450,6 @@ pub fn call_binary_strided(
     let num_dims: usize = shape.len();
     let encoder = command_buffer.new_compute_command_encoder();
     let width: usize = shape.iter().product();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     let length: usize = shape.iter().product();
@@ -483,7 +474,6 @@ pub fn call_binary_strided(
     encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -502,7 +492,6 @@ pub fn call_cast_contiguous(
     let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (length, (input, input_offset), output));
@@ -511,7 +500,6 @@ pub fn call_cast_contiguous(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -531,7 +519,6 @@ pub fn call_cast_strided(
     let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     let length: usize = shape.iter().product();
@@ -553,7 +540,6 @@ pub fn call_cast_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -573,7 +559,6 @@ pub fn call_reduce_contiguous(
     let elements_to_sum = length / out_length;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -602,7 +587,6 @@ pub fn call_reduce_contiguous(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -624,7 +608,6 @@ pub fn call_reduce_strided(
     let elements_to_sum = length / out_length;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -660,7 +643,6 @@ pub fn call_reduce_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -679,7 +661,6 @@ pub fn call_last_softmax(
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -710,7 +691,6 @@ pub fn call_last_softmax(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -730,7 +710,6 @@ pub fn call_affine(
     let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (size, mul, add, input, output));
@@ -739,7 +718,6 @@ pub fn call_affine(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -762,7 +740,6 @@ pub fn call_affine_strided(
     let size: usize = shape.iter().product();
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -783,7 +760,6 @@ pub fn call_affine_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -802,7 +778,6 @@ pub fn call_powf(
     let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (size, mul, input, output));
@@ -811,7 +786,6 @@ pub fn call_powf(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -833,7 +807,6 @@ pub fn call_powf_strided(
     let size: usize = shape.iter().product();
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -853,7 +826,6 @@ pub fn call_powf_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -872,7 +844,6 @@ pub fn call_elu(
     let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (size, mul, input, output));
@@ -881,7 +852,6 @@ pub fn call_elu(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -903,7 +873,6 @@ pub fn call_elu_strided(
     let size: usize = shape.iter().product();
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -923,7 +892,6 @@ pub fn call_elu_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -945,7 +913,6 @@ pub fn call_where_cond_strided(
     let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     let size: usize = shape.iter().product();
@@ -974,7 +941,6 @@ pub fn call_where_cond_strided(
     encoder.use_resource(right, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -1001,7 +967,6 @@ pub fn call_index_select(
 
     let encoder = command_buffer.new_compute_command_encoder();
 
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -1024,7 +989,6 @@ pub fn call_index_select(
     encoder.use_resource(ids, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -1053,7 +1017,6 @@ pub fn call_gather(
 
     let encoder = command_buffer.new_compute_command_encoder();
 
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -1076,7 +1039,6 @@ pub fn call_gather(
     encoder.use_resource(ids, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -1105,7 +1067,6 @@ pub fn call_scatter_add(
 
     let encoder = command_buffer.new_compute_command_encoder();
 
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -1128,7 +1089,6 @@ pub fn call_scatter_add(
     encoder.use_resource(ids, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -1158,7 +1118,6 @@ pub fn call_index_add(
     let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -1182,7 +1141,6 @@ pub fn call_index_add(
     encoder.use_resource(ids, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -1386,7 +1344,6 @@ pub fn call_gemm(
     let block_bytes = block_elements * bytes;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
     encoder.set_threadgroup_memory_length(0, block_bytes.into());
     encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
@@ -1430,7 +1387,6 @@ pub fn call_gemm(
     encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(grid_size, group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
 
     Ok(())
@@ -1455,7 +1411,6 @@ pub fn call_im2col1d_strided(
 
     let encoder = command_buffer.new_compute_command_encoder();
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
     set_params!(
         encoder,
@@ -1475,7 +1430,6 @@ pub fn call_im2col1d_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
 
     Ok(())
@@ -1505,7 +1459,6 @@ pub fn call_im2col_strided(
 
     let encoder = command_buffer.new_compute_command_encoder();
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
     set_params!(
         encoder,
@@ -1527,7 +1480,6 @@ pub fn call_im2col_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();

     Ok(())
@@ -1553,7 +1505,6 @@ pub fn call_upsample_nearest_2d(
     let scale_h = shape[3] as f32 / out_h as f32;
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
     set_params!(
         encoder,
@@ -1571,7 +1522,6 @@ pub fn call_upsample_nearest_2d(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();

     Ok(())
@@ -1710,7 +1660,6 @@ pub fn call_quantized_matmul_t(
     let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;

     let encoder = command_buffer.new_compute_command_encoder();
-    //encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);

     set_params!(
@@ -1743,7 +1692,6 @@ pub fn call_quantized_matmul_t(
     encoder.use_resource(output, metal::MTLResourceUsage::Write);

     encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
-    //encoder.update_fence(&kernels.fence);
     encoder.end_encoding();

     Ok(())

View File

@@ -37,8 +37,7 @@ fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
 
 fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -60,8 +59,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
 
 fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;
@@ -96,8 +94,7 @@ fn run_strided<T: Clone>(
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
     let output = new_buffer(&device, v);
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     call_unary_strided(
         &device,
         command_buffer,
@@ -278,8 +275,7 @@ fn binary_ops_bf16() {
 
 fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -409,8 +405,7 @@ fn it_cast_f16_bf16() {
 
 fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
 
@@ -445,8 +440,7 @@ fn run_affine_strided<T: Clone>(
     add: f64,
 ) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
 
@@ -595,8 +589,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
     let dst_el = ids.len() * left_size * right_size;
     let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
 
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     call_index_select(
         &device,
         &command_buffer,
@@ -631,8 +624,7 @@ fn cos_f16() {
 
 fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -662,8 +654,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
 
 fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -782,8 +773,7 @@ fn run_where_cond<I: Clone, T: Clone>(
     name: &'static str,
 ) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;
@@ -859,8 +849,7 @@ fn run_gemm<T: Clone>(
     rhs_offset: usize,
 ) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;
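With the fence gone, the per-test boilerplate above reduces to a plain device/queue/kernel-cache setup. A minimal sketch of that setup (the helper function is illustrative; the individual calls match the updated tests):

    use candle_metal_kernels::Kernels;

    // Post-cleanup test setup: no fence object is created or threaded through
    // Kernels::new() anymore.
    fn setup() -> (metal::Device, metal::CommandQueue, Kernels) {
        let device = metal::Device::system_default().expect("no Metal device found");
        let command_queue = device.new_command_queue();
        let kernels = Kernels::new();
        (device, command_queue, kernels)
    }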