More efficient cuda implementation for ConvTranspose1d. (#2211)
* More efficient cuda implementation for ConvTranspose1d.
* Small tweak.
@@ -97,6 +97,50 @@ __device__ void im2col1d(
  }
}

template <typename T>
__device__ void col2im1d(
    const size_t dst_el,
    const size_t l_out,
    const size_t l_in,
    const size_t c_out,
    const size_t k_size,
    const size_t stride,
    const T *src,
    T *dst
) {
  const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
  // src: (b_size, l_in, c_out, l_k)
  // dst: (b_size, c_out, l_out)
  if (dst_i >= dst_el) {
    return;
  }

  const size_t dst_s0 = c_out * l_out;
  const size_t dst_s1 = l_out;
  const size_t src_s0 = c_out * k_size * l_in;
  const size_t src_s1 = c_out * k_size;
  const size_t src_s2 = k_size;

  size_t tmp_dst_i = dst_i;
  const size_t b_idx = tmp_dst_i / dst_s0;
  tmp_dst_i -= b_idx * dst_s0;
  const size_t c_idx = tmp_dst_i / dst_s1;
  tmp_dst_i -= c_idx * dst_s1;
  const int l_out_idx = tmp_dst_i;

  dst[dst_i] = static_cast<T>(0);

  int l_in_idx = l_out_idx / stride;
  int k0 = l_out_idx - l_in_idx * stride;
  // l_out_idx = l_in_idx * stride + k0
  for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
    if (l_in_idx < l_in) {
      const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
      dst[dst_i] += src[src_i];
    }
  }
}

template <typename T>
__device__ void im2col(
    const size_t dst_numel,
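The new device function is written as a gather: each thread owns one destination element (b_idx, c_idx, l_out_idx) and walks every (l_in_idx, k0) pair satisfying l_out_idx = l_in_idx * stride + k0, so the accumulation into dst needs no atomics. Below is a minimal host-side sketch of the same index mapping, written as a scatter purely for cross-checking the arithmetic on small tensors; the function name col2im1d_cpu_ref and the std::vector layout are assumptions for the illustration, not part of the commit.

// Hypothetical reference, not from the commit: the same (b, l_in, c_out, k) ->
// (b, c_out, l_out) mapping as the col2im1d kernel, expressed as a scatter.
#include <cstddef>
#include <vector>

template <typename T>
void col2im1d_cpu_ref(size_t b_size, size_t l_in, size_t c_out, size_t k_size,
                      size_t stride, size_t l_out,
                      const std::vector<T> &src,  // laid out as (b_size, l_in, c_out, k_size)
                      std::vector<T> &dst) {      // laid out as (b_size, c_out, l_out)
  dst.assign(b_size * c_out * l_out, static_cast<T>(0));
  for (size_t b = 0; b < b_size; ++b)
    for (size_t l = 0; l < l_in; ++l)
      for (size_t c = 0; c < c_out; ++c)
        for (size_t k = 0; k < k_size; ++k) {
          // Column element (b, l, c, k) contributes to output position l * stride + k,
          // the inverse of the l_out_idx = l_in_idx * stride + k0 gather in the kernel.
          const size_t o = l * stride + k;
          if (o < l_out)
            dst[(b * c_out + c) * l_out + o] +=
                src[((b * l_in + l) * c_out + c) * k_size + k];
        }
}

Writing the kernel as a gather instead of this scatter trades a short strided loop per output element for not having to serialize concurrent updates to dst.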
@@ -542,6 +586,20 @@ extern "C" __global__ void FN_NAME( \
  im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
} \

#define COL2IM1D_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t dst_el, \
    const size_t l_out, \
    const size_t l_in, \
    const size_t c_out, \
    const size_t k_size, \
    const size_t stride, \
    const TYPENAME *src, \
    TYPENAME *dst \
) { \
  col2im1d<TYPENAME>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst); \
} \

#define IM2COL_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t dst_numel, \
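COL2IM1D_OP follows the same pattern as the other *_OP macros in this file: it stamps out one extern "C" __global__ entry point per dtype that simply forwards to the templated device function. For illustration only, the instantiation COL2IM1D_OP(float, col2im1d_f32) that appears further down expands to:

extern "C" __global__ void col2im1d_f32(
    const size_t dst_el,
    const size_t l_out,
    const size_t l_in,
    const size_t c_out,
    const size_t k_size,
    const size_t stride,
    const float *src,
    float *dst
) {
  col2im1d<float>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst);
}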
@@ -643,6 +701,7 @@ MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
IM2COL_OP(__nv_bfloat16, im2col_bf16)
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
#endif

#if __CUDA_ARCH__ >= 530
@@ -655,6 +714,7 @@ MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
IM2COL_OP(__half, im2col_f16)
IM2COL1D_OP(__half, im2col1d_f16)
COL2IM1D_OP(__half, col2im1d_f16)
#endif

CONV1D_OP(float, float, conv1d_f32)
@@ -701,3 +761,8 @@ IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(double, im2col1d_f64)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)

COL2IM1D_OP(float, col2im1d_f32)
COL2IM1D_OP(double, col2im1d_f64)
COL2IM1D_OP(uint8_t, col2im1d_u8)
COL2IM1D_OP(uint32_t, col2im1d_u32)
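In candle these kernels are launched from the Rust side; the sketch below only illustrates the intended launch shape in plain CUDA C++, with one thread per destination element so that dst_i in col2im1d covers dst exactly once. The wrapper name launch_col2im1d_f32, the block size of 256, and the l_out formula (valid for the zero-padding, unit-dilation case this kernel handles) are assumptions for the example, not code from the repository.

#include <cstddef>
#include <cuda_runtime.h>

// Declaration matching the COL2IM1D_OP(float, col2im1d_f32) instantiation above.
extern "C" __global__ void col2im1d_f32(
    const size_t dst_el, const size_t l_out, const size_t l_in,
    const size_t c_out, const size_t k_size, const size_t stride,
    const float *src, float *dst);

// Hypothetical host wrapper: one thread per element of dst = (b_size, c_out, l_out).
void launch_col2im1d_f32(const float *src, float *dst, size_t b_size,
                         size_t l_in, size_t c_out, size_t k_size,
                         size_t stride, cudaStream_t stream) {
  // With no padding and unit dilation, the transposed-conv output length is:
  const size_t l_out = (l_in - 1) * stride + k_size;
  const size_t dst_el = b_size * c_out * l_out;
  const unsigned threads = 256;
  const unsigned blocks = (unsigned)((dst_el + threads - 1) / threads);
  col2im1d_f32<<<blocks, threads, 0, stream>>>(dst_el, l_out, l_in, c_out,
                                               k_size, stride, src, dst);
}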