Optimize the cuda conv transpose1d kernel.
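Instead of the naive per-output-element ConvTranspose1D kernel, decompose the transposed convolution into a batched matmul followed by a new col2im1d kernel: the gemm builds a (b_size, l_in, c_out, k_size) column buffer from the input and the kernel, and col2im1d folds it back into the (b_size, c_out, l_out) output. The fast path is only taken for contiguous kernels with dilation == 1, padding == 0 and output_padding == 0; anything else falls back to the previous kernel.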

laurent
2024-03-17 19:28:37 +01:00
parent ce9fbc3682
commit 42ae70c458
2 changed files with 118 additions and 3 deletions


@ -608,6 +608,34 @@ impl Map1 for Elu {
}
}

struct Col2Im1D {
    stride: usize,
}

impl Map1 for Col2Im1D {
    fn f<T: DeviceRepr + WithDType>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>> {
        let (b_size, l_in, c_out, k_size) = layout.shape().dims4()?;
        let stride = self.stride;
        let l_out = (l_in - 1) * stride + k_size;
        let dst_el = b_size * c_out * l_out;
        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
        let src = &src.slice(layout.start_offset()..);
        let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
        // SAFETY: Set later by running the kernel.
        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
        let params = (l_in, l_out, c_out, k_size, b_size, stride, src, &dst);
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(dst)
    }
}
struct Im2Col1D {
    l_k: usize,
    stride: usize,
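For reference, the reduction that the col2im1d kernel added above computes can be written as a plain scatter on the CPU: every column element (l_in_i, k_i) accumulates into output position l_in_i * stride + k_i. A minimal sketch for illustration only; the col2im1d_ref helper and its f32-only signature are hypothetical and not part of this commit:

fn col2im1d_ref(
    src: &[f32], // (b_size, l_in, c_out, k_size), contiguous
    (b_size, l_in, c_out, k_size): (usize, usize, usize, usize),
    stride: usize,
) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    // dst: (b_size, c_out, l_out), zero-initialized accumulator.
    let mut dst = vec![0f32; b_size * c_out * l_out];
    for b_i in 0..b_size {
        for l_in_i in 0..l_in {
            for c_i in 0..c_out {
                for k_i in 0..k_size {
                    let src_i = ((b_i * l_in + l_in_i) * c_out + c_i) * k_size + k_i;
                    let dst_i = (b_i * c_out + c_i) * l_out + l_in_i * stride + k_i;
                    dst[dst_i] += src[src_i];
                }
            }
        }
    }
    dst
}

The cuda kernel in the second file inverts this scatter into a gather so that each thread owns exactly one dst element and no atomic adds are needed.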
@ -1865,9 +1893,55 @@ impl BackendStorage for CudaStorage {
        params: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        let device = self.device().clone();
        const USE_COL2IM_CONV1D_TR: bool = true;
        let can_use_col2im = kernel_l.is_contiguous()
            && params.dilation == 1
            && params.padding == 0
            && params.output_padding == 0;
        if !can_use_col2im || !USE_COL2IM_CONV1D_TR {
            let slice =
                ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
            return Ok(Self { slice, device });
        }
        let (b_size, c_in, l_in) = l.shape().dims3()?;
        let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
        if !kernel_l.is_contiguous() {
            crate::bail!("convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}")
        }
        if c_in != c_in2 {
            crate::bail!(
                "convtr1d: shape mismatch on c_in {:?} {:?}",
                l.shape(),
                kernel_l.shape()
            )
        }
        let col = {
            // This merges the last two dimensions of the kernel together and
            // broadcasts the resulting (c_in, c_out * k_size) view over the
            // batch via a 0 stride on the first dimension.
            let kernel_l_mm = Layout::new(
                (b_size, c_in, k_size * c_out).into(),
                vec![0, k_size * c_out, 1],
                kernel_l.start_offset(),
            );
            self.matmul(
                kernel,
                (
                    b_size,
                    /* m */ l_in,
                    /* n */ c_out * k_size,
                    /* k */ c_in,
                ),
                &l.transpose(1, 2)?,
                &kernel_l_mm,
            )?
        };
        let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
        let slice = Col2Im1D {
            stride: params.stride,
        }
        .map(&col.slice, &device, &col_l)?;
        Ok(Self { slice, device })
    }
#[cfg(not(feature = "cudnn"))]
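Nothing changes for callers: a stride-only transposed convolution on a cuda device now goes through the gemm + col2im path automatically. A minimal usage sketch, assuming the Tensor::conv_transpose1d(kernel, padding, output_padding, stride, dilation) signature that candle-core exposed at the time of this commit:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::new_cuda(0)?;
    // Input is (b_size, c_in, l_in), kernel is (c_in, c_out, k_size),
    // matching the dims3() checks in conv_transpose1d above.
    let x = Tensor::randn(0f32, 1f32, (2, 4, 8), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (4, 3, 5), &dev)?;
    // padding = 0, output_padding = 0, dilation = 1: the col2im fast path.
    let y = x.conv_transpose1d(&k, 0, 0, /* stride */ 2, 1)?;
    // l_out = (l_in - 1) * stride + k_size = (8 - 1) * 2 + 5 = 19.
    assert_eq!(y.dims(), &[2, 3, 19]);
    Ok(())
}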


@ -51,6 +51,47 @@ __device__ void conv1d(
  dst[dst_i] = static_cast<T>(d);
}

template <typename T>
__device__ void col2im1d(
    const size_t l_in,
    const size_t l_out,
    const size_t c_out,
    const size_t k_size,
    const size_t b_size,
    const size_t stride,
    const T *src,
    T *dst
) {
  const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
  // src: (b_size, l_in, c_out, k_size)
  // dst: (b_size, c_out, l_out)
  if (dst_i >= b_size * c_out * l_out) {
    return;
  }

  const size_t dst_s0 = c_out * l_out;
  const size_t dst_s1 = l_out;
  // dst_i = b_i * dst_s0 + c_i * dst_s1 + l_in_i * stride + k_i
  const size_t b_i = dst_i / dst_s0;
  const size_t dst_i2 = dst_i - b_i * dst_s0;
  const size_t c_i = dst_i2 / dst_s1;
  const size_t dst_i3 = dst_i2 - c_i * dst_s1; // l_in_i * stride + k_i

  const size_t src_s0 = c_out * k_size * l_in;
  const size_t src_s1 = c_out * k_size;
  const size_t src_s2 = k_size;

  // Gather formulation: each thread owns a single dst element and sums the
  // src entries that scatter onto it, i.e. the (l_in_i, k_i) pairs such that
  // l_in_i * stride + k_i == dst_i3. This avoids atomic adds entirely.
  dst[dst_i] = static_cast<T>(0);
  // k_i can reach dst_i3 itself (the l_in_i == 0 contribution), hence the
  // dst_i3 + 1 bound rather than dst_i3.
  for (size_t k_i = 0; k_i < min(dst_i3 + 1, k_size); ++k_i) {
    const size_t l_in_i_times_stride = dst_i3 - k_i;
    const size_t l_in_i = l_in_i_times_stride / stride;
    // Only exact multiples of the stride map back to an input position, and
    // positions past the end of the input (possible when k_size > stride)
    // must be skipped to avoid an out-of-bounds read.
    if (l_in_i * stride == l_in_i_times_stride && l_in_i < l_in) {
      dst[dst_i] += src[b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i];
    }
  }
}
template <typename T>
__device__ void im2col1d(
    const size_t dst_numel,
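Putting the two pieces together: the gemm computes col[b, l_in_i, c_i, k_i] as the sum over c_in of input[b, c_in, l_in_i] * kernel[c_in, c_i, k_i], and col2im1d then accumulates col[b, l_in_i, c_i, k_i] into dst[b, c_i, l_in_i * stride + k_i]. That is exactly conv_transpose1d with dilation 1 and no padding or output padding, computed as one batched matmul plus one gather kernel.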