Add support for conv_transpose1d for metal backend (#1874)

* first attempt

* progress

* integrate into metal backend

* finish and get test passing

* add other dtype support

* update transpose1d dtypes supported
This commit is contained in:
Thomas Santerre
2024-03-19 03:46:58 -04:00
committed by GitHub
parent 143c481c20
commit 2a8679509e
5 changed files with 394 additions and 10 deletions

View File

@ -948,12 +948,54 @@ impl BackendStorage for MetalStorage {
/// Runs a 1D transposed convolution of `self` with kernel `k` on the Metal device.
///
/// * `layout` - layout (dims, strides, start offset) of the input held in `self`.
/// * `k` - storage holding the convolution kernel.
/// * `k_layout` - layout (dims, strides, start offset) of the kernel.
/// * `params` - convolution parameters (dilation, stride, padding, output padding,
///   `c_out`, `b_size`, and the derived output length via `l_out()`).
///
/// Returns a new storage of `c_out * l_out * b_size` elements with the same dtype
/// as `self`.
///
/// # Errors
///
/// Fails when the dtype of `self` is not one of `F32`, `F16`, `BF16`, `U32` or `U8`,
/// when the output buffer or command buffer cannot be obtained, or when the kernel
/// call itself reports an error.
fn conv_transpose1d(
    &self,
    layout: &Layout,
    k: &Self,
    k_layout: &Layout,
    params: &ParamsConvTranspose1D,
) -> Result<Self> {
    let l_out = params.l_out();
    // Total number of elements in the output buffer.
    let dst_el = params.c_out * l_out * params.b_size;
    let buffer = self
        .device
        .new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
    let command_buffer = self.device.command_buffer()?;
    // The kernel is picked from the input dtype only.
    // NOTE(review): nothing here checks that `k.dtype` matches `self.dtype` —
    // presumably guaranteed by the caller; confirm.
    let name = match self.dtype {
        DType::F32 => "conv_transpose1d_f32",
        DType::F16 => "conv_transpose1d_f16",
        DType::BF16 => "conv_transpose1d_bf16",
        DType::U32 => "conv_transpose1d_u32",
        DType::U8 => "conv_transpose1d_u8",
        dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
    };
    candle_metal_kernels::call_conv_transpose1d(
        &self.device.device,
        &command_buffer,
        &self.device.kernels,
        name,
        params.dilation,
        params.stride,
        params.padding,
        params.output_padding,
        params.c_out,
        l_out,
        params.b_size,
        layout.dims(),
        layout.stride(),
        k_layout.dims(),
        k_layout.stride(),
        &self.buffer,
        // Layout start offsets are in elements; they are scaled to bytes here.
        layout.start_offset() * self.dtype.size_in_bytes(),
        &k.buffer,
        k_layout.start_offset() * k.dtype.size_in_bytes(),
        &buffer,
    )
    .map_err(MetalError::from)?;
    Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}
fn conv2d(