Add support for conv_transpose2d on Metal backend (#1903)

* add support for conv transpose 2d and add benchmarks for float types

* update bench calculation

* enable testing all conv operations on metal
This commit is contained in:
Thomas Santerre
2024-03-21 13:08:45 -04:00
committed by GitHub
parent ec97c98e81
commit 9563a5fee4
7 changed files with 321 additions and 76 deletions

View File

@ -1970,5 +1970,63 @@ pub fn call_conv_transpose1d(
Ok(())
}
/// Configuration bundle consumed by [`call_conv_transpose2d`].
///
/// The scalar fields and the shape/stride slices are forwarded to the compute
/// kernel as parameters; `c_out` and `b_size` are only used on the host side,
/// together with `out_w` and `out_h`, to compute the number of output elements
/// that drives the dispatch size.
#[derive(Debug, Clone)]
pub struct CallConvTranspose2dCfg<'a> {
    /// Dilation value forwarded to the kernel.
    pub dilation: usize,
    /// Stride value forwarded to the kernel.
    pub stride: usize,
    /// Padding value forwarded to the kernel.
    pub padding: usize,
    /// Output padding value forwarded to the kernel.
    pub output_padding: usize,
    /// Factor of the output element count (presumably the number of output
    /// channels — confirm against the caller).
    pub c_out: usize,
    /// Output width; forwarded to the kernel and a factor of the element count.
    pub out_w: usize,
    /// Output height; forwarded to the kernel and a factor of the element count.
    pub out_h: usize,
    /// Factor of the output element count (presumably the batch size —
    /// confirm against the caller).
    pub b_size: usize,
    /// Dimensions of the input, forwarded to the kernel as-is.
    pub input_dims: &'a [usize],
    /// Strides of the input, forwarded to the kernel as-is.
    pub input_stride: &'a [usize],
    /// Dimensions of the convolution kernel, forwarded as-is.
    pub kernel_dims: &'a [usize],
    /// Strides of the convolution kernel, forwarded as-is.
    pub kernel_stride: &'a [usize],
    /// Offset paired with the `input` buffer when it is bound
    /// (NOTE(review): unit — bytes vs. elements — is not visible here).
    pub input_offset: usize,
    /// Offset paired with the `kernel` buffer when it is bound
    /// (NOTE(review): unit — bytes vs. elements — is not visible here).
    pub kernel_offset: usize,
}
/// Encodes a 2D transposed-convolution compute pass into `command_buffer`.
///
/// The pipeline named `name` is loaded from the `Source::Conv` kernel source,
/// the parameters from `cfg` and the three buffers are bound in a fixed order,
/// and a dispatch sized from the total number of output elements
/// (`c_out * out_w * out_h * b_size`) is encoded.
///
/// # Errors
///
/// Returns the error produced by `Kernels::load_pipeline` if the pipeline
/// cannot be loaded.
pub fn call_conv_transpose2d(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    cfg: CallConvTranspose2dCfg,
    input: &Buffer,
    kernel: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    // Total number of elements in the output; this drives the dispatch size.
    let output_elements = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size;

    let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
    let (group_count, group_size) = linear_split(&pipeline, output_elements);

    // Name every kernel argument up front so the binding list below reads
    // as a plain, ordered parameter table.
    let (out_w, out_h) = (cfg.out_w, cfg.out_h);
    let (stride, padding) = (cfg.stride, cfg.padding);
    let (output_padding, dilation) = (cfg.output_padding, cfg.dilation);
    let (input_dims, input_stride) = (cfg.input_dims, cfg.input_stride);
    let (kernel_dims, kernel_stride) = (cfg.kernel_dims, cfg.kernel_stride);
    let input_binding = (input, cfg.input_offset);
    let kernel_binding = (kernel, cfg.kernel_offset);

    let enc = command_buffer.new_compute_command_encoder();
    enc.set_compute_pipeline_state(&pipeline);

    // The order here must stay in sync with the kernel's argument order.
    set_params!(
        enc,
        (
            out_w,
            out_h,
            stride,
            padding,
            output_padding,
            dilation,
            input_dims,
            input_stride,
            kernel_dims,
            kernel_stride,
            input_binding,
            kernel_binding,
            output
        )
    );

    // Declare how each buffer is accessed: both operands are read, the
    // destination is written.
    enc.use_resource(input, metal::MTLResourceUsage::Read);
    enc.use_resource(kernel, metal::MTLResourceUsage::Read);
    enc.use_resource(output, metal::MTLResourceUsage::Write);

    enc.dispatch_thread_groups(group_count, group_size);
    enc.end_encoding();
    Ok(())
}
// Unit tests live in the sibling `tests` module, compiled only for test builds.
#[cfg(test)]
mod tests;