Use RAII for terminating the encoding. (#2353)

This commit is contained in:
Laurent Mazare
2024-07-24 15:29:56 +01:00
committed by GitHub
parent a925ae6bc6
commit ddafc61055
2 changed files with 69 additions and 61 deletions

View File

@ -1,6 +1,6 @@
use metal::{ use metal::{
Buffer, CompileOptions, ComputePipelineState, Device, Function, FunctionConstantValues, Buffer, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function,
Library, MTLDataType, MTLSize, NSUInteger, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::c_void; use std::ffi::c_void;
@ -311,6 +311,7 @@ pub fn call_copy2d(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -333,7 +334,6 @@ pub fn call_copy2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_threads(grid_dims, group_dims); encoder.dispatch_threads(grid_dims, group_dims);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -349,6 +349,7 @@ pub fn call_unary_contiguous_tiled(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let tile_size = 2; let tile_size = 2;
let tiles = (length + tile_size - 1) / tile_size; let tiles = (length + tile_size - 1) / tile_size;
@ -360,7 +361,6 @@ pub fn call_unary_contiguous_tiled(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -376,6 +376,7 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -385,7 +386,6 @@ pub fn call_unary_contiguous(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -405,6 +405,7 @@ pub fn call_unary_strided(
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
let num_dims: usize = shape.len(); let num_dims: usize = shape.len();
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -412,7 +413,6 @@ pub fn call_unary_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -430,6 +430,7 @@ pub fn call_binary_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, &left, &right, output)); set_params!(encoder, (length, &left, &right, output));
@ -440,7 +441,6 @@ pub fn call_binary_contiguous(
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -461,6 +461,7 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len(); let num_dims: usize = shape.len();
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let width: usize = shape.iter().product(); let width: usize = shape.iter().product();
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
@ -483,7 +484,6 @@ pub fn call_binary_strided(
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -501,6 +501,7 @@ pub fn call_cast_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, &input, output)); set_params!(encoder, (length, &input, output));
@ -509,7 +510,6 @@ pub fn call_cast_contiguous(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -527,6 +527,7 @@ pub fn call_cast_strided(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
@ -541,7 +542,6 @@ pub fn call_cast_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -560,6 +560,7 @@ pub fn call_reduce_contiguous(
let elements_to_sum = length / out_length; let elements_to_sum = length / out_length;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, elements_to_sum, &input, output)); set_params!(encoder, (length, elements_to_sum, &input, output));
@ -585,7 +586,6 @@ pub fn call_reduce_contiguous(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -606,6 +606,7 @@ pub fn call_reduce_strided(
let elements_to_sum = length / out_length; let elements_to_sum = length / out_length;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -634,7 +635,6 @@ pub fn call_reduce_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -652,6 +652,7 @@ pub fn call_last_softmax(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -682,7 +683,6 @@ pub fn call_last_softmax(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -703,6 +703,7 @@ pub fn call_rms_norm(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -741,7 +742,6 @@ pub fn call_rms_norm(
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64); encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -764,6 +764,7 @@ pub fn call_layer_norm(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -803,7 +804,6 @@ pub fn call_layer_norm(
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64); encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -825,6 +825,7 @@ pub fn call_rope_i(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -844,7 +845,6 @@ pub fn call_rope_i(
encoder.use_resource(sin, metal::MTLResourceUsage::Read); encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -868,6 +868,7 @@ pub fn call_rope_thd(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -889,7 +890,6 @@ pub fn call_rope_thd(
encoder.use_resource(sin, metal::MTLResourceUsage::Read); encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -912,6 +912,7 @@ pub fn call_rope(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -932,7 +933,6 @@ pub fn call_rope(
encoder.use_resource(sin, metal::MTLResourceUsage::Read); encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -951,6 +951,7 @@ pub fn call_affine(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, &input, output)); set_params!(encoder, (size, mul, add, &input, output));
@ -959,7 +960,6 @@ pub fn call_affine(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -980,6 +980,7 @@ pub fn call_affine_strided(
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1000,7 +1001,6 @@ pub fn call_affine_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1018,6 +1018,7 @@ pub fn call_powf(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, &input, output)); set_params!(encoder, (size, mul, &input, output));
@ -1026,7 +1027,6 @@ pub fn call_powf(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1046,6 +1046,7 @@ pub fn call_powf_strided(
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1057,7 +1058,6 @@ pub fn call_powf_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1075,6 +1075,7 @@ pub fn call_elu(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, &input, output)); set_params!(encoder, (size, mul, &input, output));
@ -1083,7 +1084,6 @@ pub fn call_elu(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1103,6 +1103,7 @@ pub fn call_elu_strided(
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1114,7 +1115,6 @@ pub fn call_elu_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1136,6 +1136,7 @@ pub fn call_where_cond_strided(
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
@ -1164,7 +1165,6 @@ pub fn call_where_cond_strided(
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1192,6 +1192,7 @@ pub fn call_index_select(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -1218,7 +1219,6 @@ pub fn call_index_select(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1243,6 +1243,7 @@ pub fn call_gather(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -1266,7 +1267,6 @@ pub fn call_gather(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1292,6 +1292,7 @@ pub fn call_scatter_add(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -1315,7 +1316,6 @@ pub fn call_scatter_add(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1342,6 +1342,7 @@ pub fn call_index_add(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -1366,7 +1367,6 @@ pub fn call_index_add(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1573,6 +1573,7 @@ pub fn call_gemm(
let block_bytes = block_elements * bytes; let block_bytes = block_elements * bytes;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
@ -1615,8 +1616,6 @@ pub fn call_gemm(
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size); encoder.dispatch_thread_groups(grid_size, group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1637,6 +1636,7 @@ pub fn call_im2col1d_strided(
let dst_el = shape[0] * l_out * shape[1] * k_size; let dst_el = shape[0] * l_out * shape[1] * k_size;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1646,8 +1646,6 @@ pub fn call_im2col1d_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1670,6 +1668,7 @@ pub fn call_col2im1d(
let dst_el = shape[0] * c_out * l_out; let dst_el = shape[0] * c_out * l_out;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1679,8 +1678,6 @@ pub fn call_col2im1d(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1706,6 +1703,7 @@ pub fn call_im2col_strided(
let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k; let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1718,8 +1716,6 @@ pub fn call_im2col_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1742,6 +1738,7 @@ pub fn call_upsample_nearest_2d(
let scale_h = shape[3] as f32 / out_h as f32; let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -1750,8 +1747,6 @@ pub fn call_upsample_nearest_2d(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1774,6 +1769,7 @@ pub fn call_random_uniform(
} }
let pipeline = kernels.load_pipeline(device, Source::Random, name)?; let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let odd = (length % 2 != 0) as usize; let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
@ -1788,8 +1784,6 @@ pub fn call_random_uniform(
); );
encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1807,6 +1801,7 @@ pub fn call_random_normal(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?; let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let odd = (length % 2 != 0) as usize; let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
@ -1821,8 +1816,6 @@ pub fn call_random_normal(
); );
encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -1962,6 +1955,7 @@ pub fn call_quantized_matmul_mv_t(
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1993,8 +1987,6 @@ pub fn call_quantized_matmul_mv_t(
encoder.use_resource(dst, metal::MTLResourceUsage::Write); encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup); encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -2023,6 +2015,7 @@ pub fn call_pool2d(
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -2031,7 +2024,6 @@ pub fn call_pool2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -2062,6 +2054,7 @@ pub fn call_conv_transpose1d(
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -2084,7 +2077,6 @@ pub fn call_conv_transpose1d(
encoder.use_resource(kernel, metal::MTLResourceUsage::Read); encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -2120,6 +2112,7 @@ pub fn call_conv_transpose2d(
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -2143,7 +2136,6 @@ pub fn call_conv_transpose2d(
encoder.use_resource(kernel, metal::MTLResourceUsage::Read); encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }
@ -2161,6 +2153,7 @@ pub fn call_arg_sort(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?; let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
let encoder = ep.encoder(); let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64)); set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
@ -2180,7 +2173,6 @@ pub fn call_arg_sort(
encoder.use_resource(dst, metal::MTLResourceUsage::Write); encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64); encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
ep.maybe_end_encoding(encoder);
Ok(()) Ok(())
} }

View File

@ -162,24 +162,40 @@ macro_rules! set_params {
} }
pub trait EncoderProvider { pub trait EncoderProvider {
fn encoder(&self) -> &ComputeCommandEncoderRef; type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef); where
Self: 'a;
fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
}
pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
impl<'a> Drop for WrappedEncoder<'a> {
fn drop(&mut self) {
self.0.end_encoding()
}
}
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
&self.0
}
} }
impl EncoderProvider for &metal::CommandBuffer { impl EncoderProvider for &metal::CommandBuffer {
fn encoder(&self) -> &ComputeCommandEncoderRef { type Encoder<'a> = WrappedEncoder<'a>
self.new_compute_command_encoder() where
} Self: 'a;
fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) { fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
enc.end_encoding() WrappedEncoder(self.new_compute_command_encoder())
} }
} }
impl EncoderProvider for &metal::CommandBufferRef { impl EncoderProvider for &metal::CommandBufferRef {
fn encoder(&self) -> &ComputeCommandEncoderRef { type Encoder<'a> = WrappedEncoder<'a>
self.new_compute_command_encoder() where
} Self: 'a;
fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) { fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
enc.end_encoding() WrappedEncoder(self.new_compute_command_encoder())
} }
} }