mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Dilated convolutions (#657)
* Add the dilation parameter. * Restore the basic optimizer example. * Dilation support in cudnn. * Use the dilation parameter in the cpu backend. * More dilation support. * No support for dilation in transposed convolutions. * Add dilation to a test. * Remove a print. * Helper function.
This commit is contained in:
@ -8,6 +8,7 @@ __device__ void conv1d(
|
||||
const size_t l_out,
|
||||
const size_t stride,
|
||||
const size_t padding,
|
||||
const size_t dilation,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
const T *kernel,
|
||||
@ -36,7 +37,7 @@ __device__ void conv1d(
|
||||
const size_t src_idx0 = b_idx * src_s[0];
|
||||
A d = 0;
|
||||
for (size_t offset = 0; offset < k_size; ++offset) {
|
||||
size_t src_l = stride * dst_l + offset;
|
||||
size_t src_l = (stride * dst_l + offset) * dilation;
|
||||
if (src_l < padding || src_l >= padding + l_in) {
|
||||
continue;
|
||||
}
|
||||
@ -58,6 +59,7 @@ __device__ void conv2d(
|
||||
const size_t h_out,
|
||||
const size_t stride,
|
||||
const size_t padding,
|
||||
const size_t dilation,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
const T *kernel,
|
||||
@ -90,13 +92,13 @@ __device__ void conv2d(
|
||||
const size_t src_idx0 = b_idx * src_s[0];
|
||||
A d = 0;
|
||||
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
|
||||
size_t src_w = stride * dst_w + w_offset;
|
||||
size_t src_w = (stride * dst_w + w_offset) * dilation;
|
||||
if (src_w < padding || src_w >= w_in + padding) {
|
||||
continue;
|
||||
}
|
||||
src_w -= padding;
|
||||
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
|
||||
size_t src_h = stride * dst_h + h_offset;
|
||||
size_t src_h = (stride * dst_h + h_offset) * dilation;
|
||||
if (src_h < padding || src_h >= h_in + padding) {
|
||||
continue;
|
||||
}
|
||||
@ -120,6 +122,7 @@ __device__ void conv_transpose2d(
|
||||
const size_t stride,
|
||||
const size_t padding,
|
||||
const size_t out_padding,
|
||||
const size_t dilation,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
const T *kernel,
|
||||
@ -335,12 +338,13 @@ extern "C" __global__ void FN_NAME( \
|
||||
const size_t num_dims, \
|
||||
const size_t stride, \
|
||||
const size_t padding, \
|
||||
const size_t dilation, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
const TYPENAME *kernel, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, info, src, kernel, dst); \
|
||||
conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, dilation, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#define CONV2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
@ -350,12 +354,13 @@ extern "C" __global__ void FN_NAME( \
|
||||
const size_t h_out, \
|
||||
const size_t stride, \
|
||||
const size_t padding, \
|
||||
const size_t dilation, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
const TYPENAME *kernel, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
|
||||
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
@ -366,12 +371,13 @@ extern "C" __global__ void FN_NAME( \
|
||||
const size_t stride, \
|
||||
const size_t padding, \
|
||||
const size_t out_padding, \
|
||||
const size_t dilation, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
const TYPENAME *kernel, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, info, src, kernel, dst); \
|
||||
conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
|
Reference in New Issue
Block a user