im2col version of the conv1d kernel. (#815)

* im2col version of the cuda conv1d kernel.

* im2col version of the conv1d cpu kernel.
This commit is contained in:
Laurent Mazare
2023-09-11 14:40:09 +01:00
committed by GitHub
parent 5c35fbbb13
commit dbd4561416
3 changed files with 249 additions and 4 deletions

View File

@ -51,6 +51,53 @@ __device__ void conv1d(
dst[dst_i] = static_cast<T>(d);
}
template <typename T>
__device__ void im2col1d(
const size_t dst_numel,
const size_t l_out,
const size_t l_k,
const size_t stride,
const size_t padding,
const size_t dilation,
const size_t *info,
const T *src,
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// dst: (b_size, l_out, c_in, l_k)
// src: (b_size, c_in, l_in)
if (dst_i >= dst_numel) {
return;
}
const size_t *src_dims = info;
const size_t *src_s = info + 3;
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
const size_t dst_s2 = l_k;
const size_t dst_s1 = c_in * dst_s2;
const size_t dst_s0 = l_out * dst_s1;
size_t tmp_dst_i = dst_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t l_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= l_idx * dst_s1;
const size_t c_idx = tmp_dst_i / dst_s2;
tmp_dst_i -= c_idx * dst_s2;
const size_t l_k_idx = tmp_dst_i;
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
dst[dst_i] = src[src_i];
}
}
template <typename T>
__device__ void im2col(
const size_t dst_numel,
@ -78,7 +125,7 @@ __device__ void im2col(
const size_t h_in = src_dims[2];
const size_t w_in = src_dims[3];
const size_t dst_s4 = w_k;
const size_t dst_s4 = w_k;
const size_t dst_s3 = h_k * dst_s4;
const size_t dst_s2 = c_in * dst_s3;
const size_t dst_s1 = w_out * dst_s2;
@ -428,6 +475,21 @@ extern "C" __global__ void FN_NAME( \
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
} \
#define IM2COL1D_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t dst_numel, \
const size_t l_out, \
const size_t l_k, \
const size_t stride, \
const size_t padding, \
const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
TYPENAME *dst \
) { \
im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
} \
#define IM2COL_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t dst_numel, \
@ -511,6 +573,7 @@ AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
IM2COL_OP(__nv_bfloat16, im2col_bf16)
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@ -521,6 +584,7 @@ AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
IM2COL_OP(__half, im2col_f16)
IM2COL1D_OP(__half, im2col1d_f16)
#endif
CONV1D_OP(float, float, conv1d_f32)
@ -557,3 +621,8 @@ IM2COL_OP(float, im2col_f32)
IM2COL_OP(double, im2col_f64)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(double, im2col1d_f64)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)