mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
im2col based conv2d (#802)
* im2col implementation for conv2d. * Fix for the im2col implementation to match the current conv2d. * Small optimization. * Add a cuda kernel. * Handle arbitrary layouts. * Im2Col cuda code.
This commit is contained in:
@ -51,6 +51,71 @@ __device__ void conv1d(
|
||||
dst[dst_i] = static_cast<T>(d);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void im2col(
|
||||
const size_t dst_numel,
|
||||
const size_t h_out,
|
||||
const size_t w_out,
|
||||
const size_t h_k,
|
||||
const size_t w_k,
|
||||
const size_t stride,
|
||||
const size_t padding,
|
||||
const size_t dilation,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
T *dst
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// dst: (b_size, h_out, w_out, c_in, h_k, w_k)
|
||||
// src: (b_size, c_in, h_in, w_in)
|
||||
if (dst_i >= dst_numel) {
|
||||
return;
|
||||
}
|
||||
const size_t *src_dims = info;
|
||||
const size_t *src_s = info + 4;
|
||||
const size_t b_in = src_dims[0];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t h_in = src_dims[2];
|
||||
const size_t w_in = src_dims[3];
|
||||
|
||||
const size_t dst_s4 = w_k;
|
||||
const size_t dst_s3 = h_k * dst_s4;
|
||||
const size_t dst_s2 = c_in * dst_s3;
|
||||
const size_t dst_s1 = w_out * dst_s2;
|
||||
const size_t dst_s0 = h_out * dst_s1;
|
||||
|
||||
size_t tmp_dst_i = dst_i;
|
||||
const size_t b_idx = tmp_dst_i / dst_s0;
|
||||
tmp_dst_i -= b_idx * dst_s0;
|
||||
const size_t h_idx = tmp_dst_i / dst_s1;
|
||||
tmp_dst_i -= h_idx * dst_s1;
|
||||
const size_t w_idx = tmp_dst_i / dst_s2;
|
||||
tmp_dst_i -= w_idx * dst_s2;
|
||||
const size_t c_idx = tmp_dst_i / dst_s3;
|
||||
tmp_dst_i -= c_idx * dst_s3;
|
||||
const size_t h_k_idx = tmp_dst_i / dst_s4;
|
||||
tmp_dst_i -= h_k_idx * dst_s4;
|
||||
const size_t w_k_idx = tmp_dst_i;
|
||||
size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
|
||||
size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
|
||||
if (src_h_idx < padding || src_h_idx >= h_in + padding) {
|
||||
dst[dst_i] = static_cast<T>(0);
|
||||
}
|
||||
else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
|
||||
dst[dst_i] = static_cast<T>(0);
|
||||
}
|
||||
else {
|
||||
src_h_idx -= padding;
|
||||
src_w_idx -= padding;
|
||||
const size_t src_i =
|
||||
b_idx * src_s[0]
|
||||
+ c_idx * src_s[1]
|
||||
+ src_h_idx * src_s[2]
|
||||
+ src_w_idx * src_s[3];
|
||||
dst[dst_i] = src[src_i];
|
||||
}
|
||||
}
|
||||
|
||||
// Naive implementation of conv2d.
|
||||
template <typename T, typename A>
|
||||
__device__ void conv2d(
|
||||
@ -363,6 +428,23 @@ extern "C" __global__ void FN_NAME( \
|
||||
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#define IM2COL_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t dst_numel, \
|
||||
const size_t h_out, \
|
||||
const size_t w_out, \
|
||||
const size_t h_k, \
|
||||
const size_t w_k, \
|
||||
const size_t stride, \
|
||||
const size_t padding, \
|
||||
const size_t dilation, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \
|
||||
} \
|
||||
|
||||
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t src_numel, \
|
||||
@ -428,6 +510,7 @@ CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
|
||||
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
|
||||
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
|
||||
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
|
||||
IM2COL_OP(__nv_bfloat16, im2col_bf16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
@ -437,6 +520,7 @@ CONVT2D_OP(__half, float, conv_transpose2d_f16)
|
||||
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
|
||||
MAX_POOL2D_OP(__half, max_pool2d_f16)
|
||||
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
|
||||
IM2COL_OP(__half, im2col_f16)
|
||||
#endif
|
||||
|
||||
CONV1D_OP(float, float, conv1d_f32)
|
||||
@ -468,3 +552,8 @@ UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
|
||||
UPSAMPLE_NEAREST2D_OP(double, upsample_nearest2d_f64)
|
||||
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
|
||||
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
|
||||
|
||||
IM2COL_OP(float, im2col_f32)
|
||||
IM2COL_OP(double, im2col_f64)
|
||||
IM2COL_OP(uint8_t, im2col_u8)
|
||||
IM2COL_OP(uint32_t, im2col_u32)
|
||||
|
Reference in New Issue
Block a user