Adding upsample_nearest_2d.

This commit is contained in:
Nicolas Patry
2023-12-25 14:25:19 +01:00
parent 1505d85276
commit 13a5d15ebc
3 changed files with 137 additions and 2 deletions

View File

@ -1518,6 +1518,50 @@ pub fn call_im2col_strided(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_upsample_nearest_2d(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
strides: &[usize],
out_w: usize,
out_h: usize,
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
let dst_el = out_w * out_h * shape[0] * shape[1];
let scale_w = shape[2] as f32 / out_w as f32;
let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
out_w,
out_h,
scale_w,
scale_h,
shape,
strides,
(input, input_offset),
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}