Fix for backprop in ConvTranspose2D with stride of 2 (#2337)

* Add gradient test for conv_transpose2d with stride of 2.

* Swap dilation and stride in ConvTranspose2D backpropagation.

Without this, a shape mismatch occurs with a stride of 2 and dilation of 1.

* Add further tests of the ConvTranspose2D gradient.

Values calculated with torch, minor numerical errors adjusted and commented.
This commit is contained in:
Ivor Wanders
2024-07-17 13:22:23 -04:00
committed by GitHub
parent 6a4741bbf9
commit f25173d68b
2 changed files with 99 additions and 2 deletions

View File

@ -320,13 +320,13 @@ impl Tensor {
dilation,
output_padding: _output_padding,
} => {
let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?;
let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = grad
.transpose(0, 1)?
.conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)?
.conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;