mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add support for conv_transpose2d on Metal backend (#1903)
* add support for conv transpose 2d and add benchmark for float types * update benchmark calculation * enable testing all conv operations on metal
This commit is contained in:
@ -2,8 +2,8 @@ use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels;
|
||||
use candle_metal_kernels::Kernels;
|
||||
use candle_metal_kernels::{self, CallConvTranspose2dCfg};
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
@ -1074,12 +1074,66 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn conv_transpose2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &ParamsConvTranspose2D,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &ParamsConvTranspose2D,
|
||||
) -> Result<Self> {
|
||||
crate::bail!("Metal conv_tranpose2d not implemented")
|
||||
// Kernel shape: (c_in_k, c_out, h_k, w_k)
|
||||
// Input shape: (b_size, c_in, h_in, w_in)
|
||||
let (out_w, out_h) = (params.out_w(), params.out_h());
|
||||
let dst_el = params.c_out * out_w * out_h * params.b_size;
|
||||
|
||||
let dims = l.dims();
|
||||
if dims.len() != 4 {
|
||||
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}, expected 4")
|
||||
}
|
||||
|
||||
let k_dims = kernel_l.dims();
|
||||
if k_dims.len() != 4 {
|
||||
crate::bail!("unexpected kernel shape for conv_transpose2d {k_dims:?}, expected 4")
|
||||
}
|
||||
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv_transpose2d")?;
|
||||
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "conv_transpose2d_f32",
|
||||
DType::F16 => "conv_transpose2d_f16",
|
||||
DType::BF16 => "conv_transpose2d_bf16",
|
||||
dtype => crate::bail!("Metal conv_transpose2d {dtype:?} not implemented"),
|
||||
};
|
||||
|
||||
candle_metal_kernels::call_conv_transpose2d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
CallConvTranspose2dCfg {
|
||||
dilation: params.dilation,
|
||||
stride: params.stride,
|
||||
padding: params.padding,
|
||||
output_padding: params.output_padding,
|
||||
c_out: params.c_out,
|
||||
out_h: out_h,
|
||||
out_w: out_w,
|
||||
b_size: params.b_size,
|
||||
input_dims: l.dims(),
|
||||
input_stride: l.stride(),
|
||||
kernel_dims: kernel_l.dims(),
|
||||
kernel_stride: kernel_l.stride(),
|
||||
input_offset: l.start_offset() * self.dtype.size_in_bytes(),
|
||||
kernel_offset: kernel_l.start_offset() * kernel.dtype.size_in_bytes(),
|
||||
},
|
||||
&self.buffer,
|
||||
&kernel.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
|
||||
}
|
||||
|
||||
fn avg_pool2d(
|
||||
|
Reference in New Issue
Block a user