Add support for conv_transpose2d on Metal backend (#1903)

* add support for conv transpose 2d and add bench mark for float types

* update bench calculation

* enable testing all conv operations on metal
This commit is contained in:
Thomas Santerre
2024-03-21 13:08:45 -04:00
committed by GitHub
parent ec97c98e81
commit 9563a5fee4
7 changed files with 321 additions and 76 deletions

View File

@ -2,8 +2,8 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
use candle_metal_kernels::{self, CallConvTranspose2dCfg};
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
@ -1074,12 +1074,66 @@ impl BackendStorage for MetalStorage {
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConvTranspose2D,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &ParamsConvTranspose2D,
) -> Result<Self> {
crate::bail!("Metal conv_tranpose2d not implemented")
// Kernel shape: (c_in_k, c_out, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
let (out_w, out_h) = (params.out_w(), params.out_h());
let dst_el = params.c_out * out_w * out_h * params.b_size;
let dims = l.dims();
if dims.len() != 4 {
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}, expected 4")
}
let k_dims = kernel_l.dims();
if k_dims.len() != 4 {
crate::bail!("unexpected kernel shape for conv_transpose2d {k_dims:?}, expected 4")
}
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose2d")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "conv_transpose2d_f32",
DType::F16 => "conv_transpose2d_f16",
DType::BF16 => "conv_transpose2d_bf16",
dtype => crate::bail!("Metal conv_transpose2d {dtype:?} not implemented"),
};
candle_metal_kernels::call_conv_transpose2d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
CallConvTranspose2dCfg {
dilation: params.dilation,
stride: params.stride,
padding: params.padding,
output_padding: params.output_padding,
c_out: params.c_out,
out_h: out_h,
out_w: out_w,
b_size: params.b_size,
input_dims: l.dims(),
input_stride: l.stride(),
kernel_dims: kernel_l.dims(),
kernel_stride: kernel_l.stride(),
input_offset: l.start_offset() * self.dtype.size_in_bytes(),
kernel_offset: kernel_l.start_offset() * kernel.dtype.size_in_bytes(),
},
&self.buffer,
&kernel.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}
fn avg_pool2d(