Add a metal kernel for col2im1d. (#2214)

* Add a metal kernel for col2im1d.

* Enable the col2im variant.

* Bugfix.

* Revert the quantized tweak.
This commit is contained in:
Laurent Mazare
2024-05-25 11:03:23 +02:00
committed by GitHub
parent 3ceca9901a
commit 0814dfd148
3 changed files with 189 additions and 35 deletions

View File

@ -1651,6 +1651,39 @@ pub fn call_im2col1d_strided(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_col2im1d(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
k_size: usize,
stride: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
let l_in = shape[1];
let c_out = shape[2];
let l_out = (l_in - 1) * stride + k_size;
let dst_el = shape[0] * c_out * l_out;
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(dst_el, l_out, l_in, c_out, k_size, stride, &input, output)
);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_im2col_strided(
device: &Device,