Add support for conv_transpose1d for metal backend (#1874)

* first attempt

* progress

* integrate into metal backend

* finish and get test passing

* add other dtype support

* update transpose1d dtypes supported
Author: Thomas Santerre
Committed: 2024-03-19 03:46:58 -04:00 (via GitHub)
Parent: 143c481c20
Commit: 2a8679509e
5 changed files with 394 additions and 10 deletions
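
Before the per-file diffs, a minimal usage sketch of what this commit enables. This is a hypothetical example, not part of the diff; it assumes candle's `Tensor::conv_transpose1d(kernel, padding, output_padding, stride, dilation)` signature as of this commit (later versions may differ) and a build of candle-core with the `metal` feature. The numbers reuse the f32 test case from the diff below.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // With this commit, selecting a Metal device routes conv_transpose1d
    // to the new Metal kernel instead of bailing out.
    let dev = Device::new_metal(0)?;
    // Input (b_size=1, c_in=1, l_in=4) and kernel (c_in=1, c_out=1, l_k=4).
    let x = Tensor::new(&[[[1f32, 2., 3., 4.]]], &dev)?;
    let w = Tensor::new(&[[[1f32, 2., 3., 4.]]], &dev)?;
    // padding=0, output_padding=0, stride=1, dilation=1 => l_out = 7.
    let y = x.conv_transpose1d(&w, 0, 0, 1, 1)?;
    assert_eq!(
        y.flatten_all()?.to_vec1::<f32>()?,
        [1., 4., 10., 20., 25., 24., 16.]
    );
    Ok(())
}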


@@ -948,12 +948,54 @@ impl BackendStorage for MetalStorage {
     fn conv_transpose1d(
         &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &ParamsConvTranspose1D,
+        layout: &Layout,
+        k: &Self,
+        k_layout: &Layout,
+        params: &ParamsConvTranspose1D,
     ) -> Result<Self> {
-        crate::bail!("Metal conv_transpose1d not implemented")
+        let device = self.device().clone();
+        let l_out = params.l_out();
+        let dst_el = params.c_out * l_out * params.b_size;
+        let buffer = self
+            .device
+            .new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
+        let command_buffer = self.device.command_buffer()?;
+        // Pick the kernel variant matching the storage dtype.
+        let name = match self.dtype {
+            DType::F32 => "conv_transpose1d_f32",
+            DType::F16 => "conv_transpose1d_f16",
+            DType::BF16 => "conv_transpose1d_bf16",
+            DType::U32 => "conv_transpose1d_u32",
+            DType::U8 => "conv_transpose1d_u8",
+            dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
+        };
+        candle_metal_kernels::call_conv_transpose1d(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            params.dilation,
+            params.stride,
+            params.padding,
+            params.output_padding,
+            params.c_out,
+            l_out,
+            params.b_size,
+            layout.dims(),
+            layout.stride(),
+            k_layout.dims(),
+            k_layout.stride(),
+            &self.buffer,
+            layout.start_offset() * self.dtype.size_in_bytes(),
+            &k.buffer,
+            k_layout.start_offset() * k.dtype.size_in_bytes(),
+            &buffer,
+        )
+        .map_err(MetalError::from)?;
+        Ok(Self::new(buffer, device, dst_el, self.dtype))
     }
 
     fn conv2d(
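
For reference, `params.l_out()` here is the usual transposed-convolution output length, which the test helper further down recomputes directly as

l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1

For the test case used throughout this commit (l_in = 4, k_size = 4, stride = 1, padding = 0, dilation = 1, out_padding = 0) this gives l_out = 3 + 3 + 1 = 7, matching the seven expected output values in the tests below.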


@@ -54,11 +54,6 @@ fn conv1d(dev: &Device) -> Result<()> {
         [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
     );
-    // conv-transposes are not implemented for metal.
-    if dev.is_metal() {
-        return Ok(());
-    }
     let w = w.transpose(0, 1)?;
     // The CPU kernels applied in the contiguous and non contiguous cases are different.
     for w in [w.clone(), w.contiguous()?] {


@@ -335,6 +335,76 @@ kernel void FN_NAME( \
    max_pool2d<TYPENAME>(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \
} \

// Naive implementation of conv_transpose1d.
template <typename T, typename A>
METAL_FUNC void conv_transpose1d(
    constant size_t &l_out,
    constant size_t &stride,
    constant size_t &padding,
    constant size_t &out_padding,
    constant size_t &dilation,
    constant size_t *src_dims,
    constant size_t *src_strides,
    constant size_t *k_dims,
    constant size_t *k_strides,
    device const T *src,
    device const T *k,
    device T *dst,
    uint tid [[ thread_position_in_grid ]]
) {
    // src: (b_size, c_in, l_in)
    // kernel: (c_in, c_out, l_k)
    const size_t l_k = k_dims[2];
    const size_t c_out = k_dims[1];
    const size_t c_in = src_dims[1];
    const size_t l_in = src_dims[2];

    // One thread per output element; extra threads from the padded grid exit early.
    if (tid >= src_dims[0] * c_out * l_out) {
        return;
    }

    const size_t b_idx = tid / (l_out * c_out);
    const size_t dst_c_idx = (tid / l_out) % c_out;
    const size_t out_x = tid % l_out;

    const size_t src_idx0 = b_idx * src_strides[0];
    A d = 0;
    for (int k_x = 0; k_x < (int)l_k; ++k_x) {
        // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
        // Cast dilation to int so the subtraction stays in signed arithmetic.
        int inp_x_stride = (int)(out_x + padding) - k_x * (int)dilation;
        if (inp_x_stride < 0 || inp_x_stride % stride) {
            continue;
        }
        int inp_x = inp_x_stride / stride;
        if (inp_x >= l_in) continue;
        for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
            const size_t src_idx = src_idx0 + src_c_idx * src_strides[1] + inp_x * src_strides[2];
            const size_t k_idx = src_c_idx * k_strides[0] + dst_c_idx * k_strides[1] + k_x * k_strides[2];
            d += static_cast<A>(src[src_idx]) * static_cast<A>(k[k_idx]);
        }
    }
    dst[tid] = static_cast<T>(d);
}

#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \
kernel void FN_NAME( \
    constant size_t &l_out, \
    constant size_t &stride, \
    constant size_t &padding, \
    constant size_t &out_padding, \
    constant size_t &dilation, \
    constant size_t *src_dims, \
    constant size_t *src_strides, \
    constant size_t *k_dims, \
    constant size_t *k_strides, \
    device const TYPENAME *src, \
    device const TYPENAME *k, \
    device TYPENAME *dst, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    conv_transpose1d<TYPENAME, TYPEACC>(l_out, stride, padding, out_padding, dilation, src_dims, src_strides, k_dims, k_strides, src, k, dst, tid); \
} \
IM2COL_OP(float, im2col_f32)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
@@ -361,4 +431,12 @@ AVGPOOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32)
AVGPOOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
#if defined(__HAVE_BFLOAT__)
AVGPOOL2D_OP(bfloat, float, avg_pool2d_bf16)
#endif
CONVT1D_OP(float, float, conv_transpose1d_f32)
CONVT1D_OP(half, float, conv_transpose1d_f16)
CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
#if defined(__HAVE_BFLOAT__)
CONVT1D_OP(bfloat, float, conv_transpose1d_bf16)
#endif
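
The template is written as a gather rather than a scatter: each thread owns one output element (b_idx, dst_c_idx, out_x) and inverts the forward relation out_x = inp_x * stride + k_x * dilation - padding to find which input positions feed it, so no atomics are needed. As a worked trace against the f32 test below (stride = 1, dilation = 1, padding = 0, input [1, 2, 3, 4], kernel [1, 2, 3, 4]), the thread for out_x = 2 accumulates:

k_x = 0 -> inp_x = 2: d += 3 * 1
k_x = 1 -> inp_x = 1: d += 2 * 2
k_x = 2 -> inp_x = 0: d += 1 * 3
k_x = 3 -> inp_x = -1: skipped (inp_x_stride < 0)

which gives d = 10, the third value of the expected output [1, 4, 10, 20, 25, 24, 16].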


@@ -1859,5 +1859,58 @@ pub fn call_pool2d(
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose1d(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    dilation: usize,
    stride: usize,
    padding: usize,
    out_padding: usize,
    c_out: usize,
    l_out: usize,
    b_size: usize,
    src_shape: &[usize],
    src_strides: &[usize],
    kernel_shape: &[usize],
    kernel_strides: &[usize],
    input: &Buffer,
    input_offset: usize,
    kernel: &Buffer,
    kernel_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    // One thread per output element.
    let dst_el = c_out * l_out * b_size;
    let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);
    // Argument order must match the kernel signature in conv.metal.
    set_params!(
        encoder,
        (
            l_out,
            stride,
            padding,
            out_padding,
            dilation,
            src_shape,
            src_strides,
            kernel_shape,
            kernel_strides,
            (input, input_offset),
            (kernel, kernel_offset),
            output
        )
    );
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}
#[cfg(test)]
mod tests;
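
For orientation, `linear_split` is the crate's shared helper for 1-D dispatch; its definition is not part of this diff. A plausible sketch of the behavior the call above relies on (names and details assumed, and the real helper returns Metal MTLSize values rather than plain integers):

// Assumed sketch, not the actual candle-metal-kernels implementation:
// split `length` work items into (threadgroup count, threadgroup size).
fn linear_split_sketch(max_threads_per_threadgroup: u64, length: u64) -> (u64, u64) {
    let width = max_threads_per_threadgroup.min(length); // threads per group
    let count = (length + width - 1) / width;            // ceil-divide into groups
    // count * width >= length; the kernel's `tid >= ...` guard drops the overshoot.
    (count, width)
}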


@@ -1717,3 +1717,219 @@ fn avg_pool2d_u32() {
    let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12];
    assert_eq!(results, expected);
}

fn run_conv_transpose1d<T: Clone>(
    input: &[T],
    input_shape: &[usize],
    input_stride: &[usize],
    kernel: &[T],
    kernel_shape: &[usize],
    kernel_stride: &[usize],
    dilation: usize,
    stride: usize,
    padding: usize,
    out_padding: usize,
    name: &'static str,
) -> Vec<T> {
    let device = device();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let c_out = kernel_shape[1];
    let k_size = kernel_shape[2];
    let b_size = input_shape[0];
    let l_in = input_shape[2];
    let l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1;
    let dst_el = c_out * l_out * b_size;

    let input = new_buffer(&device, input);
    let kernel = new_buffer(&device, kernel);
    // f32 zeros over-allocate for the narrower dtypes under test (f16, bf16, u8),
    // which is harmless: read_to_vec only reads dst_el elements of T.
    let output = new_buffer(&device, &vec![0.0f32; dst_el]);
    let kernels = Kernels::new();

    call_conv_transpose1d(
        &device,
        command_buffer,
        &kernels,
        name,
        dilation,
        stride,
        padding,
        out_padding,
        c_out,
        l_out,
        b_size,
        input_shape,
        input_stride,
        kernel_shape,
        kernel_stride,
        &input,
        0,
        &kernel,
        0,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&output, dst_el)
}
#[test]
fn conv_transpose1d_f32() {
    let input = vec![1.0f32, 2.0, 3.0, 4.0];
    let input_shape = &[1, 1, 4];
    let input_stride = &[4, 4, 1];

    let kernel = vec![1.0f32, 2.0, 3.0, 4.0];
    let kernel_shape = &[1, 1, 4];
    let kernel_stride = &[4, 4, 1];

    let results = run_conv_transpose1d(
        &input,
        input_shape,
        input_stride,
        &kernel,
        kernel_shape,
        kernel_stride,
        1,
        1,
        0,
        0,
        "conv_transpose1d_f32",
    );

    let expected = vec![1., 4., 10., 20., 25., 24., 16.];
    assert_eq!(results, expected);
}

#[test]
fn conv_transpose1d_f16() {
    let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect();
    let input_shape = &[1, 1, 4];
    let input_stride = &[4, 4, 1];

    let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect();
    let kernel_shape = &[1, 1, 4];
    let kernel_stride = &[4, 4, 1];

    let results = run_conv_transpose1d(
        &input,
        input_shape,
        input_stride,
        &kernel,
        kernel_shape,
        kernel_stride,
        1,
        1,
        0,
        0,
        "conv_transpose1d_f16",
    );

    let expected = vec![1., 4., 10., 20., 25., 24., 16.]
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    assert_eq!(results, expected);
}

#[test]
fn conv_transpose1d_bf16() {
    let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect();
    let input_shape = &[1, 1, 4];
    let input_stride = &[4, 4, 1];

    let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect();
    let kernel_shape = &[1, 1, 4];
    let kernel_stride = &[4, 4, 1];

    let results = run_conv_transpose1d(
        &input,
        input_shape,
        input_stride,
        &kernel,
        kernel_shape,
        kernel_stride,
        1,
        1,
        0,
        0,
        "conv_transpose1d_bf16",
    );

    let expected = vec![1., 4., 10., 20., 25., 24., 16.]
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();
    assert_eq!(results, expected);
}

#[test]
fn conv_transpose1d_u8() {
    let input: Vec<u8> = vec![1, 2, 3, 4];
    let input_shape = &[1, 1, 4];
    let input_stride = &[4, 4, 1];

    let kernel: Vec<u8> = vec![1, 2, 3, 4];
    let kernel_shape = &[1, 1, 4];
    let kernel_stride = &[4, 4, 1];

    let results = run_conv_transpose1d(
        &input,
        input_shape,
        input_stride,
        &kernel,
        kernel_shape,
        kernel_stride,
        1,
        1,
        0,
        0,
        "conv_transpose1d_u8",
    );

    let expected = vec![1, 4, 10, 20, 25, 24, 16];
    assert_eq!(results, expected);
}
#[test]
fn conv_transpose1d_u32() {
    let input: Vec<u32> = vec![1, 2, 3, 4];
    let input_shape = &[1, 1, 4];
    let input_stride = &[4, 4, 1];

    let kernel: Vec<u32> = vec![1, 2, 3, 4];
    let kernel_shape = &[1, 1, 4];
    let kernel_stride = &[4, 4, 1];

    let results = run_conv_transpose1d(
        &input,
        input_shape,
        input_stride,
        &kernel,
        kernel_shape,
        kernel_stride,
        1,
        1,
        0,
        0,
        "conv_transpose1d_u32",
    );

    let expected = vec![1, 4, 10, 20, 25, 24, 16];
    assert_eq!(results, expected);
}
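
One caveat worth noting on the integer variants: the CONVT1D_OP instantiations above use uint8_t and uint32_t as their accumulator types, so u8 sums wrap once they cross 255. The u8 test stays well below that (its largest output is 25), but a slightly larger case would overflow, e.g. an input [16] with kernel [16] gives out[0] = 16 * 16 = 256, which wraps to 0 in u8.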