|
|
|
@ -204,17 +204,15 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
|
|
|
|
|
/// Shared state handed to the `call_*` kernel launchers as `kernels`.
///
/// Holds two `RwLock`-guarded collections plus a `metal::Fence`.
///
/// NOTE(review): this text looks like a unified diff whose `+`/`-` column was
/// stripped, so removed and added lines are interleaved; the hunk line counts
/// suggest the `fence` field may be a removed line — confirm against the real file.
pub struct Kernels {
|
|
|
|
|
/// Lock-guarded `Libraries` collection.
/// The `Libraries` type is declared outside this view; presumably a cache of
/// compiled Metal libraries — TODO confirm.
libraries: RwLock<Libraries>,
|
|
|
|
|
/// Lock-guarded `Pipelines` map; the alias shown in the hunk header keys it by
/// `(&'static str, Option<ConstantValues>)`.
/// Presumably filled by `load_pipeline`, which is not visible here — TODO confirm.
pipelines: RwLock<Pipelines>,
|
|
|
|
|
/// Fence used by the visible `call_*` bodies: each encoder calls
/// `wait_for_fence(&kernels.fence)` before setting the pipeline state and
/// `update_fence(&kernels.fence)` after dispatching.
fence: metal::Fence,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Kernels {
|
|
|
|
|
pub fn new(fence: metal::Fence) -> Self {
|
|
|
|
|
pub fn new() -> Self {
|
|
|
|
|
let libraries = RwLock::new(Libraries::new());
|
|
|
|
|
let pipelines = RwLock::new(Pipelines::new());
|
|
|
|
|
Self {
|
|
|
|
|
libraries,
|
|
|
|
|
pipelines,
|
|
|
|
|
fence,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -334,7 +332,6 @@ pub fn call_unary_contiguous(
|
|
|
|
|
) -> Result<(), MetalKernelError> {
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(encoder, (length, input, output));
|
|
|
|
@ -343,7 +340,6 @@ pub fn call_unary_contiguous(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -365,7 +361,6 @@ pub fn call_unary_strided(
|
|
|
|
|
|
|
|
|
|
let num_dims: usize = shape.len();
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
let length: usize = shape.iter().product();
|
|
|
|
@ -387,7 +382,6 @@ pub fn call_unary_strided(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -406,7 +400,6 @@ pub fn call_binary_contiguous(
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(encoder, (length, left, right, output));
|
|
|
|
@ -417,7 +410,6 @@ pub fn call_binary_contiguous(
|
|
|
|
|
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -442,7 +434,6 @@ pub fn call_binary_strided(
|
|
|
|
|
let num_dims: usize = shape.len();
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
let width: usize = shape.iter().product();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
let length: usize = shape.iter().product();
|
|
|
|
@ -467,7 +458,6 @@ pub fn call_binary_strided(
|
|
|
|
|
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -486,7 +476,6 @@ pub fn call_cast_contiguous(
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(encoder, (length, (input, input_offset), output));
|
|
|
|
@ -495,7 +484,6 @@ pub fn call_cast_contiguous(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -515,7 +503,6 @@ pub fn call_cast_strided(
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
let length: usize = shape.iter().product();
|
|
|
|
@ -537,7 +524,6 @@ pub fn call_cast_strided(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -557,7 +543,6 @@ pub fn call_reduce_contiguous(
|
|
|
|
|
let elements_to_sum = length / out_length;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
@ -586,7 +571,6 @@ pub fn call_reduce_contiguous(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -608,7 +592,6 @@ pub fn call_reduce_strided(
|
|
|
|
|
let elements_to_sum = length / out_length;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
@ -644,7 +627,6 @@ pub fn call_reduce_strided(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -663,7 +645,6 @@ pub fn call_last_softmax(
|
|
|
|
|
) -> Result<(), MetalKernelError> {
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
@ -694,7 +675,6 @@ pub fn call_last_softmax(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -714,7 +694,6 @@ pub fn call_affine(
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(encoder, (size, mul, add, input, output));
|
|
|
|
@ -723,7 +702,6 @@ pub fn call_affine(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -746,7 +724,6 @@ pub fn call_affine_strided(
|
|
|
|
|
let size: usize = shape.iter().product();
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
@ -767,8 +744,8 @@ pub fn call_affine_strided(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -786,7 +763,6 @@ pub fn call_powf(
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(encoder, (size, mul, input, output));
|
|
|
|
@ -795,8 +771,8 @@ pub fn call_powf(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -817,7 +793,6 @@ pub fn call_powf_strided(
|
|
|
|
|
let size: usize = shape.iter().product();
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
@ -837,7 +812,6 @@ pub fn call_powf_strided(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -856,7 +830,6 @@ pub fn call_elu(
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(encoder, (size, mul, input, output));
|
|
|
|
@ -865,7 +838,6 @@ pub fn call_elu(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -887,7 +859,6 @@ pub fn call_elu_strided(
|
|
|
|
|
let size: usize = shape.iter().product();
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
@ -907,7 +878,6 @@ pub fn call_elu_strided(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -929,7 +899,6 @@ pub fn call_where_cond_strided(
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
let size: usize = shape.iter().product();
|
|
|
|
@ -958,7 +927,6 @@ pub fn call_where_cond_strided(
|
|
|
|
|
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
@ -984,8 +952,6 @@ pub fn call_index_select(
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
@ -1008,8 +974,8 @@ pub fn call_index_select(
|
|
|
|
|
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1036,8 +1002,6 @@ pub fn call_gather(
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
@ -1060,8 +1024,8 @@ pub fn call_gather(
|
|
|
|
|
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1088,8 +1052,6 @@ pub fn call_scatter_add(
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
@ -1112,8 +1074,8 @@ pub fn call_scatter_add(
|
|
|
|
|
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1141,8 +1103,6 @@ pub fn call_index_add(
|
|
|
|
|
|
|
|
|
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
@ -1166,8 +1126,8 @@ pub fn call_index_add(
|
|
|
|
|
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1370,7 +1330,6 @@ pub fn call_gemm(
|
|
|
|
|
let block_bytes = block_elements * bytes;
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
|
|
|
|
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
|
|
|
@ -1415,7 +1374,6 @@ pub fn call_gemm(
|
|
|
|
|
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(grid_size, group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
@ -1440,8 +1398,8 @@ pub fn call_im2col1d_strided(
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
|
encoder,
|
|
|
|
|
(
|
|
|
|
@ -1460,7 +1418,6 @@ pub fn call_im2col1d_strided(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
@ -1490,8 +1447,8 @@ pub fn call_im2col_strided(
|
|
|
|
|
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
|
encoder,
|
|
|
|
|
(
|
|
|
|
@ -1512,9 +1469,7 @@ pub fn call_im2col_strided(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1538,8 +1493,8 @@ pub fn call_upsample_nearest_2d(
|
|
|
|
|
let scale_h = shape[3] as f32 / out_h as f32;
|
|
|
|
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
|
|
|
|
let encoder = command_buffer.new_compute_command_encoder();
|
|
|
|
|
encoder.wait_for_fence(&kernels.fence);
|
|
|
|
|
encoder.set_compute_pipeline_state(&pipeline);
|
|
|
|
|
|
|
|
|
|
set_params!(
|
|
|
|
|
encoder,
|
|
|
|
|
(
|
|
|
|
@ -1556,7 +1511,6 @@ pub fn call_upsample_nearest_2d(
|
|
|
|
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
|
|
|
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
|
|
|
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
|
|
|
encoder.update_fence(&kernels.fence);
|
|
|
|
|
encoder.end_encoding();
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|