Fix the conv2d gradient computation. (#1214)

This commit is contained in:
Laurent Mazare
2023-10-29 10:50:04 +01:00
committed by GitHub
parent 55bc3382cf
commit 46d6566c99
2 changed files with 72 additions and 0 deletions

View File

@ -238,6 +238,13 @@ impl Tensor {
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {