mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add support for conv_transpose1d for metal backend (#1874)
* first attempt * progress * integrate into metal backend * finish and get test passing * add other dtype support * update transpose1d dtypes supported
This commit is contained in:
@ -948,12 +948,54 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &ParamsConvTranspose1D,
|
||||
layout: &Layout,
|
||||
k: &Self,
|
||||
k_layout: &Layout,
|
||||
params: &ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
crate::bail!("Metal conv_transpose1d not implemented")
|
||||
let device = self.device().clone();
|
||||
|
||||
let l_out = params.l_out();
|
||||
let dst_el = params.c_out * l_out * params.b_size;
|
||||
|
||||
let dst_el = params.c_out * l_out * params.b_size;
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
|
||||
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "conv_transpose1d_f32",
|
||||
DType::F16 => "conv_transpose1d_f16",
|
||||
DType::BF16 => "conv_transpose1d_bf16",
|
||||
DType::U32 => "conv_transpose1d_u32",
|
||||
DType::U8 => "conv_transpose1d_u8",
|
||||
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_conv_transpose1d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
params.dilation,
|
||||
params.stride,
|
||||
params.padding,
|
||||
params.output_padding,
|
||||
params.c_out,
|
||||
l_out,
|
||||
params.b_size,
|
||||
layout.dims(),
|
||||
layout.stride(),
|
||||
k_layout.dims(),
|
||||
k_layout.stride(),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&k.buffer,
|
||||
k_layout.start_offset() * k.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
|
@ -54,11 +54,6 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
|
||||
// conv-transposes are not implemented for metal.
|
||||
if dev.is_metal() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let w = w.transpose(0, 1)?;
|
||||
// The CPU kernels applied in the contiguous and non contiguous cases are different.
|
||||
for w in [w.clone(), w.contiguous()?] {
|
||||
|
@ -335,6 +335,76 @@ kernel void FN_NAME( \
|
||||
max_pool2d<TYPENAME>(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \
|
||||
} \
|
||||
|
||||
|
||||
// Naive implementation of conv_transpose1d.
|
||||
template <typename T, typename A>
|
||||
METAL_FUNC void conv_transpose1d(
|
||||
constant size_t &l_out,
|
||||
constant size_t &stride,
|
||||
constant size_t &padding,
|
||||
constant size_t &out_padding,
|
||||
constant size_t &dilation,
|
||||
constant size_t *src_dims,
|
||||
constant size_t *src_strides,
|
||||
constant size_t *k_dims,
|
||||
constant size_t *k_strides,
|
||||
device const T *src,
|
||||
device const T *k,
|
||||
device T *dst,
|
||||
uint tid [[ thread_position_in_grid ]]
|
||||
) {
|
||||
// src: (b_size, c_in, l_in)
|
||||
// kernel: (c_in, c_out, l_k)
|
||||
const size_t l_k = k_dims[2];
|
||||
const size_t c_out = k_dims[1];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t l_in = src_dims[2];
|
||||
if (tid >= src_dims[0] * c_out * l_out) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t b_idx = tid / (l_out * c_out);
|
||||
const size_t dst_c_idx = (tid / l_out) % c_out;
|
||||
const size_t out_x = tid % l_out;
|
||||
|
||||
const size_t src_idx0 = b_idx * src_strides[0];
|
||||
A d = 0;
|
||||
for (int k_x = 0; k_x < (int)l_k; ++k_x) {
|
||||
// let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
|
||||
int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
|
||||
if (inp_x_stride < 0 || inp_x_stride % stride) {
|
||||
continue;
|
||||
}
|
||||
int inp_x = inp_x_stride / stride;
|
||||
if (inp_x >= l_in) continue;
|
||||
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
|
||||
const size_t src_idx = src_idx0 + src_c_idx * src_strides[1] + inp_x * src_strides[2];
|
||||
const size_t k_idx = src_c_idx * k_strides[0] + dst_c_idx * k_strides[1] + k_x * k_strides[2];
|
||||
d += static_cast<A>(src[src_idx]) * static_cast<A>(k[k_idx]);
|
||||
}
|
||||
}
|
||||
dst[tid] = static_cast<T>(d);
|
||||
}
|
||||
|
||||
#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &l_out, \
|
||||
constant size_t &stride, \
|
||||
constant size_t &padding, \
|
||||
constant size_t &out_padding, \
|
||||
constant size_t &dilation, \
|
||||
constant size_t *src_dims, \
|
||||
constant size_t *src_strides, \
|
||||
constant size_t *k_dims, \
|
||||
constant size_t *k_strides, \
|
||||
device const TYPENAME *src, \
|
||||
device const TYPENAME *k, \
|
||||
device TYPENAME *dst, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
conv_transpose1d<TYPENAME, TYPEACC>(l_out, stride, padding, out_padding, dilation, src_dims, src_strides, k_dims, k_strides, src, k, dst, tid); \
|
||||
} \
|
||||
|
||||
IM2COL_OP(float, im2col_f32)
|
||||
IM2COL_OP(uint8_t, im2col_u8)
|
||||
IM2COL_OP(uint32_t, im2col_u32)
|
||||
@ -361,4 +431,12 @@ AVGPOOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32)
|
||||
AVGPOOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
AVGPOOL2D_OP(bfloat, float, avg_pool2d_bf16)
|
||||
#endif
|
||||
|
||||
CONVT1D_OP(float, float, conv_transpose1d_f32)
|
||||
CONVT1D_OP(half, float, conv_transpose1d_f16)
|
||||
CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
|
||||
CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CONVT1D_OP(bfloat, float, conv_transpose1d_bf16)
|
||||
#endif
|
@ -1859,5 +1859,58 @@ pub fn call_pool2d(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_conv_transpose1d(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
dilation: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
out_padding: usize,
|
||||
c_out: usize,
|
||||
l_out: usize,
|
||||
b_size: usize,
|
||||
src_shape: &[usize],
|
||||
src_strides: &[usize],
|
||||
kernel_shape: &[usize],
|
||||
kernel_strides: &[usize],
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
kernel: &Buffer,
|
||||
kernel_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let dst_el = c_out * l_out * b_size;
|
||||
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
l_out,
|
||||
stride,
|
||||
padding,
|
||||
out_padding,
|
||||
dilation,
|
||||
src_shape,
|
||||
src_strides,
|
||||
kernel_shape,
|
||||
kernel_strides,
|
||||
(input, input_offset),
|
||||
(kernel, kernel_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
@ -1717,3 +1717,219 @@ fn avg_pool2d_u32() {
|
||||
let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
fn run_conv_transpose1d<T: Clone>(
|
||||
input: &[T],
|
||||
input_shape: &[usize],
|
||||
input_stride: &[usize],
|
||||
kernel: &[T],
|
||||
kernel_shape: &[usize],
|
||||
kernel_stride: &[usize],
|
||||
dilation: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
out_padding: usize,
|
||||
name: &'static str,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
let c_out = kernel_shape[1];
|
||||
let k_size = kernel_shape[2];
|
||||
let b_size = input_shape[0];
|
||||
let l_in = input_shape[2];
|
||||
let l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1;
|
||||
let dst_el = c_out * l_out * b_size;
|
||||
|
||||
let input = new_buffer(&device, input);
|
||||
let kernel = new_buffer(&device, kernel);
|
||||
let output = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||
let kernels = Kernels::new();
|
||||
|
||||
call_conv_transpose1d(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
dilation,
|
||||
stride,
|
||||
padding,
|
||||
out_padding,
|
||||
c_out,
|
||||
l_out,
|
||||
b_size,
|
||||
input_shape,
|
||||
input_stride,
|
||||
kernel_shape,
|
||||
kernel_stride,
|
||||
&input,
|
||||
0,
|
||||
&kernel,
|
||||
0,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec(&output, dst_el)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv_transpose1d_f32() {
|
||||
let input = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let kernel_shape = &[1, 1, 4];
|
||||
let kernel_stride = &[4, 4, 1];
|
||||
|
||||
let results = run_conv_transpose1d(
|
||||
&input,
|
||||
input_shape,
|
||||
input_stride,
|
||||
&kernel,
|
||||
kernel_shape,
|
||||
kernel_stride,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
"conv_transpose1d_f32",
|
||||
);
|
||||
|
||||
let expected = vec![1., 4., 10., 20., 25., 24., 16.];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv_transpose1d_f16() {
|
||||
let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
let kernel_shape = &[1, 1, 4];
|
||||
let kernel_stride = &[4, 4, 1];
|
||||
|
||||
let results = run_conv_transpose1d(
|
||||
&input,
|
||||
input_shape,
|
||||
input_stride,
|
||||
&kernel,
|
||||
kernel_shape,
|
||||
kernel_stride,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
"conv_transpose1d_f16",
|
||||
);
|
||||
|
||||
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv_transpose1d_bf16() {
|
||||
let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect();
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect();
|
||||
let kernel_shape = &[1, 1, 4];
|
||||
let kernel_stride = &[4, 4, 1];
|
||||
|
||||
let results = run_conv_transpose1d(
|
||||
&input,
|
||||
input_shape,
|
||||
input_stride,
|
||||
&kernel,
|
||||
kernel_shape,
|
||||
kernel_stride,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
"conv_transpose1d_bf16",
|
||||
);
|
||||
|
||||
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv_transpose1d_u8() {
|
||||
let input: Vec<u8> = vec![1, 2, 3, 4];
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel: Vec<u8> = vec![1, 2, 3, 4];
|
||||
let kernel_shape = &[1, 1, 4];
|
||||
let kernel_stride = &[4, 4, 1];
|
||||
|
||||
let results = run_conv_transpose1d(
|
||||
&input,
|
||||
input_shape,
|
||||
input_stride,
|
||||
&kernel,
|
||||
kernel_shape,
|
||||
kernel_stride,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
"conv_transpose1d_u8",
|
||||
);
|
||||
|
||||
let expected = vec![1, 4, 10, 20, 25, 24, 16];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv_transpose1d_u32() {
|
||||
let input: Vec<u32> = vec![1, 2, 3, 4];
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel: Vec<u32> = vec![1, 2, 3, 4];
|
||||
let kernel_shape = &[1, 1, 4];
|
||||
let kernel_stride = &[4, 4, 1];
|
||||
|
||||
let results = run_conv_transpose1d(
|
||||
&input,
|
||||
input_shape,
|
||||
input_stride,
|
||||
&kernel,
|
||||
kernel_shape,
|
||||
kernel_stride,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
"conv_transpose1d_u32",
|
||||
);
|
||||
|
||||
let expected = vec![1, 4, 10, 20, 25, 24, 16];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
Reference in New Issue
Block a user