mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Add support for conv_transpose2d on Metal backend (#1903)
* add support for conv transpose 2d and add bench mark for float types * update bench calculation * enable testing all conv operations on metal
This commit is contained in:
@ -1970,5 +1970,63 @@ pub fn call_conv_transpose1d(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct CallConvTranspose2dCfg<'a> {
|
||||
pub dilation: usize,
|
||||
pub stride: usize,
|
||||
pub padding: usize,
|
||||
pub output_padding: usize,
|
||||
pub c_out: usize,
|
||||
pub out_w: usize,
|
||||
pub out_h: usize,
|
||||
pub b_size: usize,
|
||||
pub input_dims: &'a [usize],
|
||||
pub input_stride: &'a [usize],
|
||||
pub kernel_dims: &'a [usize],
|
||||
pub kernel_stride: &'a [usize],
|
||||
pub input_offset: usize,
|
||||
pub kernel_offset: usize,
|
||||
}
|
||||
|
||||
pub fn call_conv_transpose2d(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
cfg: CallConvTranspose2dCfg,
|
||||
input: &Buffer,
|
||||
kernel: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size;
|
||||
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
cfg.out_w,
|
||||
cfg.out_h,
|
||||
cfg.stride,
|
||||
cfg.padding,
|
||||
cfg.output_padding,
|
||||
cfg.dilation,
|
||||
cfg.input_dims,
|
||||
cfg.input_stride,
|
||||
cfg.kernel_dims,
|
||||
cfg.kernel_stride,
|
||||
(input, cfg.input_offset),
|
||||
(kernel, cfg.kernel_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
Reference in New Issue
Block a user