diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 7488d939..155f49c5 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -238,6 +238,13 @@ impl Tensor { .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)? .transpose(0, 1)?; let sum_grad = grads.or_insert(kernel)?; + let (_, _, k0, k1) = kernel.dims4()?; + let (_, _, g_k0, g_k1) = grad_kernel.dims4()?; + let grad_kernel = if g_k0 != k0 || g_k1 != k1 { + grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)? + } else { + grad_kernel + }; *sum_grad = sum_grad.add(&grad_kernel)?; } Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported { diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 937ddf67..e7fdf138 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -479,6 +479,71 @@ fn conv2d_grad(dev: &Device) -> Result<()> { ] ] ); + + // Replicate the issue from https://github.com/huggingface/candle/issues/1212 + let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?; + let loss = res.sqr()?.sum_all()?; + assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32); + let grads = loss.backward()?; + let grad_t = grads.get(&t).unwrap(); + let grad_w = grads.get(&w).unwrap(); + assert_eq!(grad_t.dims(), [1, 4, 5, 5]); + assert_eq!(grad_w.dims(), [2, 4, 3, 3]); + assert_eq!( + test_utils::to_vec3_round(&grad_t.i(0)?, 2)?, + [ + [ + [9.29, -7.03, 7.87, 0.0, 0.0], + [-1.8, -7.82, 5.9, 0.0, 0.0], + [-3.12, 4.49, 5.52, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ], + [ + [21.73, 3.39, 4.77, 0.0, 0.0], + [8.25, 3.73, 27.61, 0.0, 0.0], + [-20.55, -5.61, -2.77, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ], + [ + [-8.98, 9.91, -7.15, 0.0, 0.0], + [4.93, -0.33, 4.56, 0.0, 0.0], + [-6.7, -5.76, -8.05, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ], + [ + [23.54, 6.98, -10.0, 0.0, 0.0], + [9.65, 6.18, 18.72, 0.0, 0.0], + [3.29, -5.27, 0.79, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ] + ] + ); + assert_eq!( + test_utils::to_vec3_round(&grad_w.i(0)?, 2)?, + [ + [ + [-3.47, 7.44, 0.66], + [12.89, -3.4, -9.29], + [-14.16, -0.83, 7.14] + ], + [ + [-3.23, 5.37, -3.02], + [-2.12, -11.24, 1.94], + [6.97, 7.2, 2.99] + ], + [ + [-4.04, -3.31, 4.87], + [-6.68, -5.68, 1.73], + [-5.54, 4.32, 0.52] + ], + [[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]] + ] + ); + Ok(()) }