mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Support dilation in conv-transpose2d. (#671)
This commit is contained in:
@ -155,15 +155,15 @@ __device__ void conv_transpose2d(
|
||||
const size_t src_idx0 = b_idx * src_s[0];
|
||||
A d = 0;
|
||||
for (int k_x = 0; k_x < (int)w_k; ++k_x) {
|
||||
// let out_x = inp_x * p.stride + k_x - p.padding;
|
||||
int inp_x_stride = (int)(out_x + padding) - k_x;
|
||||
// let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
|
||||
int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
|
||||
if (inp_x_stride < 0 || inp_x_stride % stride) {
|
||||
continue;
|
||||
}
|
||||
int inp_x = inp_x_stride / stride;
|
||||
if (inp_x >= w_in) continue;
|
||||
for (int k_y = 0; k_y < (int)h_k; ++k_y) {
|
||||
int inp_y_stride = (int)(out_y + padding) - k_y;
|
||||
int inp_y_stride = (int)(out_y + padding) - k_y * dilation;
|
||||
if (inp_y_stride < 0 || inp_y_stride % stride) {
|
||||
continue;
|
||||
}
|
||||
|
Reference in New Issue
Block a user