Add support for conv_transpose1d for metal backend (#1874)

* first attempt

* progress

* integrate into metal backend

* finish and get test passing

* add other dtype support

* update transpose1d dtypes supported
This commit is contained in:
Thomas Santerre
2024-03-19 03:46:58 -04:00
committed by GitHub
parent 143c481c20
commit 2a8679509e
5 changed files with 394 additions and 10 deletions

View File

@ -1859,5 +1859,58 @@ pub fn call_pool2d(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose1d(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
dilation: usize,
stride: usize,
padding: usize,
out_padding: usize,
c_out: usize,
l_out: usize,
b_size: usize,
src_shape: &[usize],
src_strides: &[usize],
kernel_shape: &[usize],
kernel_strides: &[usize],
input: &Buffer,
input_offset: usize,
kernel: &Buffer,
kernel_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let dst_el = c_out * l_out * b_size;
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
l_out,
stride,
padding,
out_padding,
dilation,
src_shape,
src_strides,
kernel_shape,
kernel_strides,
(input, input_offset),
(kernel, kernel_offset),
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[cfg(test)]
mod tests;