diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index cf354f45..303d69ff 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -948,12 +948,54 @@ impl BackendStorage for MetalStorage { fn conv_transpose1d( &self, - _l: &Layout, - _kernel: &Self, - _kernel_l: &Layout, - _params: &ParamsConvTranspose1D, + layout: &Layout, + k: &Self, + k_layout: &Layout, + params: &ParamsConvTranspose1D, ) -> Result { - crate::bail!("Metal conv_transpose1d not implemented") + let device = self.device().clone(); + + let l_out = params.l_out(); + let dst_el = params.c_out * l_out * params.b_size; + + let dst_el = params.c_out * l_out * params.b_size; + let buffer = self + .device + .new_buffer(dst_el, self.dtype, "conv_transpose1d")?; + + let command_buffer = self.device.command_buffer()?; + let name = match self.dtype { + DType::F32 => "conv_transpose1d_f32", + DType::F16 => "conv_transpose1d_f16", + DType::BF16 => "conv_transpose1d_bf16", + DType::U32 => "conv_transpose1d_u32", + DType::U8 => "conv_transpose1d_u8", + dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"), + }; + candle_metal_kernels::call_conv_transpose1d( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + params.dilation, + params.stride, + params.padding, + params.output_padding, + params.c_out, + l_out, + params.b_size, + layout.dims(), + layout.stride(), + k_layout.dims(), + k_layout.stride(), + &self.buffer, + layout.start_offset() * self.dtype.size_in_bytes(), + &k.buffer, + k_layout.start_offset() * k.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) } fn conv2d( diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index ba60b778..71bf65be 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -54,11 +54,6 @@ fn conv1d(dev: &Device) -> Result<()> { [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] ); - // conv-transposes are not implemented for metal. - if dev.is_metal() { - return Ok(()); - } - let w = w.transpose(0, 1)?; // The CPU kernels applied in the contiguous and non contiguous cases are different. for w in [w.clone(), w.contiguous()?] { diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal index 7f7a75cf..a258ae58 100644 --- a/candle-metal-kernels/src/conv.metal +++ b/candle-metal-kernels/src/conv.metal @@ -335,6 +335,76 @@ kernel void FN_NAME( \ max_pool2d(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \ } \ + +// Naive implementation of conv_transpose1d. +template +METAL_FUNC void conv_transpose1d( + constant size_t &l_out, + constant size_t &stride, + constant size_t &padding, + constant size_t &out_padding, + constant size_t &dilation, + constant size_t *src_dims, + constant size_t *src_strides, + constant size_t *k_dims, + constant size_t *k_strides, + device const T *src, + device const T *k, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + // src: (b_size, c_in, l_in) + // kernel: (c_in, c_out, l_k) + const size_t l_k = k_dims[2]; + const size_t c_out = k_dims[1]; + const size_t c_in = src_dims[1]; + const size_t l_in = src_dims[2]; + if (tid >= src_dims[0] * c_out * l_out) { + return; + } + + const size_t b_idx = tid / (l_out * c_out); + const size_t dst_c_idx = (tid / l_out) % c_out; + const size_t out_x = tid % l_out; + + const size_t src_idx0 = b_idx * src_strides[0]; + A d = 0; + for (int k_x = 0; k_x < (int)l_k; ++k_x) { + // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding; + int inp_x_stride = (int)(out_x + padding) - k_x * dilation; + if (inp_x_stride < 0 || inp_x_stride % stride) { + continue; + } + int inp_x = inp_x_stride / stride; + if (inp_x >= l_in) continue; + for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { + const size_t src_idx = src_idx0 + src_c_idx * src_strides[1] + inp_x * src_strides[2]; + const size_t k_idx = src_c_idx * k_strides[0] + dst_c_idx * k_strides[1] + k_x * k_strides[2]; + d += static_cast(src[src_idx]) * static_cast(k[k_idx]); + } + } + dst[tid] = static_cast(d); +} + +#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &l_out, \ + constant size_t &stride, \ + constant size_t &padding, \ + constant size_t &out_padding, \ + constant size_t &dilation, \ + constant size_t *src_dims, \ + constant size_t *src_strides, \ + constant size_t *k_dims, \ + constant size_t *k_strides, \ + device const TYPENAME *src, \ + device const TYPENAME *k, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + conv_transpose1d(l_out, stride, padding, out_padding, dilation, src_dims, src_strides, k_dims, k_strides, src, k, dst, tid); \ +} \ + IM2COL_OP(float, im2col_f32) IM2COL_OP(uint8_t, im2col_u8) IM2COL_OP(uint32_t, im2col_u32) @@ -361,4 +431,12 @@ AVGPOOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32) AVGPOOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8) #if defined(__HAVE_BFLOAT__) AVGPOOL2D_OP(bfloat, float, avg_pool2d_bf16) +#endif + +CONVT1D_OP(float, float, conv_transpose1d_f32) +CONVT1D_OP(half, float, conv_transpose1d_f16) +CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8) +CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32) +#if defined(__HAVE_BFLOAT__) +CONVT1D_OP(bfloat, float, conv_transpose1d_bf16) #endif \ No newline at end of file diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 1161501f..f12463a4 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1859,5 +1859,58 @@ pub fn call_pool2d( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_conv_transpose1d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + dilation: usize, + stride: usize, + padding: usize, + out_padding: usize, + c_out: usize, + l_out: usize, + b_size: usize, + src_shape: &[usize], + src_strides: &[usize], + kernel_shape: &[usize], + kernel_strides: &[usize], + input: &Buffer, + input_offset: usize, + kernel: &Buffer, + kernel_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = c_out * l_out * b_size; + let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + l_out, + stride, + padding, + out_padding, + dilation, + src_shape, + src_strides, + kernel_shape, + kernel_strides, + (input, input_offset), + (kernel, kernel_offset), + output + ) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(kernel, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 19e160dd..5045a4a3 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1717,3 +1717,219 @@ fn avg_pool2d_u32() { let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12]; assert_eq!(results, expected); } + +fn run_conv_transpose1d( + input: &[T], + input_shape: &[usize], + input_stride: &[usize], + kernel: &[T], + kernel_shape: &[usize], + kernel_stride: &[usize], + dilation: usize, + stride: usize, + padding: usize, + out_padding: usize, + name: &'static str, +) -> Vec { + let device = device(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let c_out = kernel_shape[1]; + let k_size = kernel_shape[2]; + let b_size = input_shape[0]; + let l_in = input_shape[2]; + let l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1; + let dst_el = c_out * l_out * b_size; + + let input = new_buffer(&device, input); + let kernel = new_buffer(&device, kernel); + let output = new_buffer(&device, &vec![0.0f32; dst_el]); + let kernels = Kernels::new(); + + call_conv_transpose1d( + &device, + command_buffer, + &kernels, + name, + dilation, + stride, + padding, + out_padding, + c_out, + l_out, + b_size, + input_shape, + input_stride, + kernel_shape, + kernel_stride, + &input, + 0, + &kernel, + 0, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, dst_el) +} + +#[test] +fn conv_transpose1d_f32() { + let input = vec![1.0f32, 2.0, 3.0, 4.0]; + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel = vec![1.0f32, 2.0, 3.0, 4.0]; + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_f32", + ); + + let expected = vec![1., 4., 10., 20., 25., 24., 16.]; + assert_eq!(results, expected); +} + +#[test] +fn conv_transpose1d_f16() { + let input: Vec = vec![1.0, 2.0, 3.0, 4.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel: Vec = vec![1.0, 2.0, 3.0, 4.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_f16", + ); + + let expected = vec![1., 4., 10., 20., 25., 24., 16.] + .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn conv_transpose1d_bf16() { + let input: Vec = vec![1.0, 2.0, 3.0, 4.0] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect(); + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel: Vec = vec![1.0, 2.0, 3.0, 4.0] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect(); + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_bf16", + ); + + let expected = vec![1., 4., 10., 20., 25., 24., 16.] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn conv_transpose1d_u8() { + let input: Vec = vec![1, 2, 3, 4]; + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel: Vec = vec![1, 2, 3, 4]; + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_u8", + ); + + let expected = vec![1, 4, 10, 20, 25, 24, 16]; + assert_eq!(results, expected); +} + +#[test] +fn conv_transpose1d_u32() { + let input: Vec = vec![1, 2, 3, 4]; + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel: Vec = vec![1, 2, 3, 4]; + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_u32", + ); + + let expected = vec![1, 4, 10, 20, 25, 24, 16]; + assert_eq!(results, expected); +}