diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 52d1b558..f370f490 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -608,6 +608,34 @@ impl Map1 for Elu {
     }
 }
 
+struct Col2Im1D {
+    stride: usize,
+}
+
+impl Map1 for Col2Im1D {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let (b_size, l_in, c_out, k_size) = layout.shape().dims4()?;
+        let stride = self.stride;
+        let l_out = (l_in - 1) * stride + k_size;
+
+        let dst_el = b_size * c_out * l_out;
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let src = &src.slice(layout.start_offset()..);
+        let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
+        // SAFETY: Set later by running the kernel.
+        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let params = (l_in, l_out, c_out, k_size, b_size, stride, src, &dst);
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(dst)
+    }
+}
+
 struct Im2Col1D {
     l_k: usize,
     stride: usize,
@@ -1865,9 +1893,55 @@ impl BackendStorage for CudaStorage {
         params: &crate::conv::ParamsConvTranspose1D,
     ) -> Result<Self> {
         let device = self.device().clone();
-        let slice =
-            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
-        Ok(Self { slice, device })
+        const USE_COL2IM_CONV1D_TR: bool = true;
+
+        let can_use_col2im = kernel_l.is_contiguous()
+            && params.dilation == 1
+            && params.padding == 0
+            && params.output_padding == 0;
+        if !can_use_col2im || !USE_COL2IM_CONV1D_TR {
+            let slice =
+                ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+            return Ok(Self { slice, device });
+        }
+
+        let (b_size, c_in, l_in) = l.shape().dims3()?;
+        let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
+        if !kernel_l.is_contiguous() {
+            crate::bail!("convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}")
+        }
+        if c_in != c_in2 {
+            crate::bail!(
+                "convtr1d: shape mismatch on c_in {:?} {:?}",
+                l.shape(),
+                kernel_l.shape()
+            )
+        }
+        let col = {
+            // This merges the last two dimensions of the kernel together.
+            let kernel_l_mm = Layout::new(
+                (b_size, c_in, k_size * c_out).into(),
+                vec![0, k_size * c_out, 1],
+                kernel_l.start_offset(),
+            );
+            self.matmul(
+                kernel,
+                (
+                    b_size,
+                    /* m */ l_in,
+                    /* n */ c_out * k_size,
+                    /* k */ c_in,
+                ),
+                &l.transpose(1, 2)?,
+                &kernel_l_mm,
+            )?
+        };
+        let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
+        let slice = Col2Im1D {
+            stride: params.stride,
+        }
+        .map(&col.slice, &device, &col_l)?;
+        Ok(Self { slice, device })
     }
 
     #[cfg(not(feature = "cudnn"))]
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index fed920f1..df7b04cc 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -51,6 +51,47 @@ __device__ void conv1d(
   dst[dst_i] = static_cast<T>(d);
 }
 
+template <typename T>
+__device__ void col2im1d(
+    const size_t l_in,
+    const size_t l_out,
+    const size_t c_out,
+    const size_t k_size,
+    const size_t b_size,
+    const size_t stride,
+    const T *src,
+    T *dst
+) {
+  const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+  // src: (b_size, l_in, c_out, k_size)
+  // dst: (b_size, c_out, l_out)
+  if (dst_i >= b_size * c_out * l_out) {
+    return;
+  }
+  const size_t dst_s0 = c_out * l_out;
+  const size_t dst_s1 = l_out;
+
+  // dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_in_i * stride + k_i
+  const size_t b_i = dst_i / dst_s0;
+  const size_t dst_i2 = dst_i - b_i * dst_s0;
+  const size_t c_i = dst_i2 / dst_s1;
+  const size_t dst_i3 = dst_i2 - c_i * dst_s1; // l_in_i * stride + k_i
+
+  const size_t src_s0 = c_out * k_size * l_in;
+  const size_t src_s1 = c_out * k_size;
+  const size_t src_s2 = k_size;
+
+  dst[dst_i] = static_cast<T>(0);
+  for (size_t k_i = 0; k_i < min(k_size, dst_i3 + 1); ++k_i) {
+    const size_t l_in_i_times_stride = dst_i3 - k_i;
+    const size_t l_in_i = l_in_i_times_stride / stride;
+    if (l_in_i * stride == l_in_i_times_stride && l_in_i < l_in) {
+      const size_t src_i = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
+      dst[dst_i] += src[src_i];
+    }
+  }
+}
+
 template <typename T>
 __device__ void im2col1d(
     const size_t dst_numel,
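For context on the Rust side of the patch: with padding, output_padding and dilation left at zero/one, a conv-transpose1d is a batched (l_in, c_in) x (c_in, c_out * k_size) matmul that produces "columns" of shape (b, l_in, c_out, k_size), followed by a col2im scatter into (b, c_out, l_out) with l_out = (l_in - 1) * stride + k_size. A minimal CPU sketch of that decomposition follows; the function name, the flat-slice interface, and the f32 element type are illustrative only and not part of the patch.

```rust
// Reference decomposition: matmul into columns, then col2im scatter.
fn conv_transpose1d_ref(
    x: &[f32], // (b, c_in, l_in), contiguous
    w: &[f32], // (c_in, c_out, k_size), contiguous
    b_size: usize,
    c_in: usize,
    l_in: usize,
    c_out: usize,
    k_size: usize,
    stride: usize,
) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    // Step 1: col[b, l, c, k] = sum over ci of x[b, ci, l] * w[ci, c, k],
    // i.e. an (l_in, c_in) x (c_in, c_out * k_size) matmul per batch element.
    let mut col = vec![0f32; b_size * l_in * c_out * k_size];
    for b in 0..b_size {
        for l in 0..l_in {
            for c in 0..c_out {
                for k in 0..k_size {
                    let mut acc = 0f32;
                    for ci in 0..c_in {
                        acc += x[(b * c_in + ci) * l_in + l] * w[(ci * c_out + c) * k_size + k];
                    }
                    col[((b * l_in + l) * c_out + c) * k_size + k] = acc;
                }
            }
        }
    }
    // Step 2: col2im scatter; each column entry lands at position
    // l * stride + k of its output channel.
    let mut dst = vec![0f32; b_size * c_out * l_out];
    for b in 0..b_size {
        for l in 0..l_in {
            for c in 0..c_out {
                for k in 0..k_size {
                    dst[(b * c_out + c) * l_out + l * stride + k] +=
                        col[((b * l_in + l) * c_out + c) * k_size + k];
                }
            }
        }
    }
    dst
}
```

In the patch, step 1 is handled by the existing batched matmul (kernel_l_mm broadcasts the reshaped kernel across the batch via a zero batch stride), and step 2 is the new col2im1d kernel.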
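The CUDA kernel implements the same scatter as a gather: each thread owns one element of the (b, c_out, l_out) output and sums every column entry (l_in_i, k_i) whose position l_in_i * stride + k_i equals its own. A gather-style CPU sketch, mirroring the kernel's loop bound min(k_size, pos + 1) and its l_in_i < l_in guard (again, the name and f32 type are made up for illustration):

```rust
// Gather formulation of col2im1d: one pass per output element.
fn col2im1d_gather_ref(
    col: &[f32], // (b, l_in, c_out, k_size), contiguous
    b_size: usize,
    l_in: usize,
    c_out: usize,
    k_size: usize,
    stride: usize,
) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    let mut dst = vec![0f32; b_size * c_out * l_out];
    for b in 0..b_size {
        for c in 0..c_out {
            for pos in 0..l_out {
                let mut acc = 0f32;
                // Sum all (l, k) with l * stride + k == pos, 0 <= k < k_size, l < l_in.
                for k in 0..k_size.min(pos + 1) {
                    let l_times_stride = pos - k;
                    let l = l_times_stride / stride;
                    if l * stride == l_times_stride && l < l_in {
                        acc += col[((b * l_in + l) * c_out + c) * k_size + k];
                    }
                }
                dst[(b * c_out + c) * l_out + pos] = acc;
            }
        }
    }
    dst
}
```

Formulating the kernel as a gather means each output element is written by exactly one thread, so no atomic adds are needed, unlike the scatter formulation in the sketch above.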