Compare commits

...

6 Commits

Author SHA1 Message Date
53f951f6e2 Merge remote-tracking branch 'origin/main' into cuda-conv-tr1d 2024-03-17 21:17:56 +01:00
52e70856ea Tweaks. 2024-03-17 20:48:21 +01:00
3cae6f5e9a Zero padding. 2024-03-17 20:24:34 +01:00
dffafd1049 Small optimization. 2024-03-17 20:15:51 +01:00
75f2aea5fd Fix the kernel. 2024-03-17 19:55:54 +01:00
42ae70c458 Optimize the cuda conv transpose1d kernel. 2024-03-17 19:28:37 +01:00
2 changed files with 137 additions and 9 deletions

View File

@ -608,6 +608,34 @@ impl Map1 for Elu {
}
}
struct Col2Im1D {
stride: usize,
}
impl Map1 for Col2Im1D {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let (b_size, l_in, c_out, k_size) = layout.shape().dims4()?;
let stride = self.stride;
let l_out = (l_in - 1) * stride + k_size;
let dst_el = b_size * c_out * l_out;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (l_in, l_out, c_out, k_size, b_size, stride, src, &dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
struct Im2Col1D {
l_k: usize,
stride: usize,
@ -1865,8 +1893,54 @@ impl BackendStorage for CudaStorage {
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
let device = self.device().clone();
let slice =
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
const USE_COL2IM_CONV1D_TR: bool = true;
let can_use_col2im = kernel_l.is_contiguous()
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
if !can_use_col2im || !USE_COL2IM_CONV1D_TR {
let slice =
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
return Ok(Self { slice, device });
}
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {
crate::bail!("convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}")
}
if c_in != c_in2 {
crate::bail!(
"convtr1d: shape mismatch on c_in {:?} {:?}",
l.shape(),
kernel_l.shape()
)
}
let col = {
// This merges the last two dimensions of the kernel together.
let kernel_l_mm = Layout::new(
(b_size, c_in, k_size * c_out).into(),
vec![0, k_size * c_out, 1],
kernel_l.start_offset(),
);
self.matmul(
kernel,
(
b_size,
/* m */ l_in,
/* n */ c_out * k_size,
/* k */ c_in,
),
&l.transpose(1, 2)?,
&kernel_l_mm,
)?
};
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
let slice = Col2Im1D {
stride: params.stride,
}
.map(&col.slice, &device, &col_l)?;
Ok(Self { slice, device })
}

View File

@ -51,6 +51,48 @@ __device__ void conv1d(
dst[dst_i] = static_cast<T>(d);
}
template <typename T>
__device__ void col2im1d(
const size_t l_in,
const size_t l_out,
const size_t c_out,
const size_t k_size,
const size_t b_size,
const size_t stride,
const T *src,
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// src: (b_size, l_in, c_out, k_size)
// dst: (b_size, c_out, l_out)
if (dst_i >= b_size * c_out * l_out) {
return;
}
const size_t dst_s0 = c_out * l_out;
const size_t dst_s1 = l_out;
// dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_in_i * stride + k_i
const size_t b_i = dst_i / dst_s0;
const size_t dst_i2 = dst_i - b_i * dst_s0;
const size_t c_i = dst_i2 / dst_s1;
const size_t dst_i3 = dst_i2 - c_i * dst_s1; // l_in_i * stride + k_i
const size_t src_s0 = c_out * k_size * l_in;
const size_t src_s1 = c_out * k_size;
const size_t src_s2 = k_size;
T d = 0;
for (size_t k_i = 0; k_i < min(dst_i3 + 1, k_size); ++k_i) {
const size_t l_in_i_times_stride = dst_i3 - k_i;
const size_t l_in_i = l_in_i_times_stride / stride;
const size_t src_i = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
if (l_in_i * stride == l_in_i_times_stride && l_in_i < l_in) {
d += src[src_i];
}
}
dst[dst_i] = d;
}
template <typename T>
__device__ void im2col1d(
const size_t dst_numel,
@ -527,7 +569,7 @@ extern "C" __global__ void FN_NAME( \
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
} \
#define IM2COL1D_OP(TYPENAME, FN_NAME) \
#define IM2COL1D_OP(TYPENAME, FN_NAME, FN_NAME2) \
extern "C" __global__ void FN_NAME( \
const size_t dst_numel, \
const size_t l_out, \
@ -541,6 +583,18 @@ extern "C" __global__ void FN_NAME( \
) { \
im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
} \
extern "C" __global__ void FN_NAME2( \
const size_t l_in, \
const size_t l_out, \
const size_t c_out, \
const size_t k_size, \
const size_t b_size, \
const size_t stride, \
const TYPENAME *src, \
TYPENAME *dst \
) { \
col2im1d<TYPENAME>(l_in, l_out, c_out, k_size, b_size, stride, src, dst); \
} \
#define IM2COL_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
@ -642,7 +696,7 @@ AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
IM2COL_OP(__nv_bfloat16, im2col_bf16)
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16, col2im1d_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@ -654,7 +708,7 @@ AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
IM2COL_OP(__half, im2col_f16)
IM2COL1D_OP(__half, im2col1d_f16)
IM2COL1D_OP(__half, im2col1d_f16, col2im1d_f16)
#endif
CONV1D_OP(float, float, conv1d_f32)
@ -697,7 +751,7 @@ IM2COL_OP(double, im2col_f64)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(double, im2col1d_f64)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
IM2COL1D_OP(float, im2col1d_f32, col2im1d_f32)
IM2COL1D_OP(double, im2col1d_f64, col2im1d_f64)
IM2COL1D_OP(uint8_t, im2col1d_u8, col2im1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32, col2im1d_u32)