mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Add support for max_pool2d for Metal backend (#1863)
* First pass at implementation of maxpool2d
* Add definitions for other dtypes
* Add tests for other dtypes
* Cosmetic tweaks + re-enable maxpool2d tests for Metal.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -1826,5 +1826,38 @@ fn divide(m: usize, b: usize) -> NSUInteger {
|
||||
((m + b - 1) / b) as NSUInteger
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_max_pool2d(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
out_w: usize,
|
||||
out_h: usize,
|
||||
w_k: usize,
|
||||
h_k: usize,
|
||||
w_stride: usize,
|
||||
h_stride: usize,
|
||||
input: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let dst_el = out_w * out_h * shape[0] * shape[1];
|
||||
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(w_k, h_k, w_stride, h_stride, shape, strides, input, output)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
Reference in New Issue
Block a user