Add support for max_pool2d for Metal backend (#1863)

* first pass at implementation of maxpool2d

* Add definitions for other dtypes

* add tests for other dtypes

* Cosmetic tweaks + re-enable maxpool2d tests for metal.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Thomas Santerre
2024-03-18 03:33:30 -04:00
committed by GitHub
parent 184105792f
commit 754fa1e813
5 changed files with 394 additions and 7 deletions

View File

@ -1826,5 +1826,38 @@ fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}
#[allow(clippy::too_many_arguments)]
pub fn call_max_pool2d(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
strides: &[usize],
out_w: usize,
out_h: usize,
w_k: usize,
h_k: usize,
w_stride: usize,
h_stride: usize,
input: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let dst_el = out_w * out_h * shape[0] * shape[1];
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(w_k, h_k, w_stride, h_stride, shape, strides, input, output)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[cfg(test)]
mod tests;