Backprop for conv2d. (#638)

* Start adding backprop for conv2d.

* Backprop for conv2d.

* Bugfix + start adding a conv2d test.

* Conv2d backprop testing.

* More conv fixes.
This commit is contained in:
Laurent Mazare
2023-08-28 16:08:55 +01:00
committed by GitHub
parent 9137c63175
commit b292047882
4 changed files with 106 additions and 13 deletions

View File

@ -192,7 +192,30 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
Op::Conv2D {
arg,
kernel,
padding,
stride,
} => {
// The output height for conv_transpose2d is:
// (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
let grad_h = grad.dim(2)?;
let k_h = kernel.dim(2)?;
let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg =
grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
.conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d",
})?,