diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 22c28ac4..4366f3b6 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -192,7 +192,30 @@ impl Tensor {
                     *f_sum_grad = f_sum_grad.add(&f_grad)?;
                 }
                 Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
-                Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
+                Op::Conv2D {
+                    arg,
+                    kernel,
+                    padding,
+                    stride,
+                } => {
+                    // The output height for conv_transpose2d is:
+                    // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
+                    let grad_h = grad.dim(2)?;
+                    let k_h = kernel.dim(2)?;
+                    let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
+                    let out_padding = arg.dim(2)? - out_size;
+                    let grad_arg =
+                        grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad_arg)?;
+
+                    let grad_kernel = arg
+                        .transpose(0, 1)?
+                        .conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
+                        .transpose(0, 1)?;
+                    let sum_grad = grads.or_insert(kernel)?;
+                    *sum_grad = sum_grad.add(&grad_kernel)?;
+                }
                 Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
                     op: "conv-transpose2d",
                 })?,
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 3455247b..d9e0a9ab 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -71,18 +71,14 @@ pub struct ParamsConvTranspose2D {
 impl ParamsConvTranspose2D {
     pub(crate) fn out_h(&self) -> usize {
         let dilation = 1;
-        (self.i_h - 1) * self.stride - 2 * self.padding
-            + dilation * (self.k_h - 1)
-            + self.output_padding
-            + 1
+        (self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
+            - 2 * self.padding
     }
 
     pub(crate) fn out_w(&self) -> usize {
         let dilation = 1;
-        (self.i_w - 1) * self.stride - 2 * self.padding
-            + dilation * (self.k_w - 1)
-            + self.output_padding
-            + 1
+        (self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
+            - 2 * self.padding
     }
 
     pub(crate) fn out_dims(&self) -> Vec<usize> {
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 0b19904b..bc4470d6 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1204,7 +1204,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
                     let inp_x = out_x * p.stride as i32 - p.padding as i32;
                     let inp_y = out_y * p.stride as i32 - p.padding as i32;
                     for k_y in 0..p.k_h as i32 {
-                        for k_x in 0..p.k_h as i32 {
+                        for k_x in 0..p.k_w as i32 {
                             let k_index = k_y as usize * k_s2 + k_x as usize * k_s3;
                             let inp_y = inp_y + k_y;
                             let inp_x = inp_x + k_x;
@@ -1215,9 +1215,11 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
                             let inp_y = inp_y as usize;
                             if inp_x < p.i_w && inp_y < p.i_h {
                                 let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3;
-                                let dst_index = b_idx * dst_s0 + inp_y * dst_s2 + inp_x * dst_s3;
-                                for c_out in 0..k_s0 {
-                                    for c_in in 0..k_s1 {
+                                let dst_index = b_idx * dst_s0
+                                    + out_y as usize * dst_s2
+                                    + out_x as usize * dst_s3;
+                                for c_out in 0..p.c_out {
+                                    for c_in in 0..p.c_in {
                                         let k_index = k_index + c_out * k_s1 + c_in * k_s0;
                                         let dst_index = dst_index + c_out * dst_s1;
                                         let inp_index = inp_index + c_in * inp_s1;
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 310d2462..2bd4e9df 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -243,6 +243,78 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn conv2d_grad() -> Result<()> {
+    use candle_core::Var;
+    let dev = &Device::Cpu;
+    let t = Var::from_slice(
+        &[
+            0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
+            1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
+            1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
+            0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
+            1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
+            0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
+            0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
+            0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
+            -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
+            -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
+        ],
+        (1, 4, 5, 5),
+        dev,
+    )?;
+    let w = Var::from_slice(
+        &[
+            -0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
+            -2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
+            -0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
+            0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
+            0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
+            -0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
+            1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
+            0.5583, 0.4623, 0.6026,
+        ],
+        (2, 4, 3, 3),
+        dev,
+    )?;
+    let res = t.conv2d(&w, 0, 1, 1)?;
+    let loss = res.sqr()?.sum_all()?;
+    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
+    let grads = loss.backward()?;
+    let grad_t = grads.get(&t).unwrap();
+    let grad_w = grads.get(&w).unwrap();
+    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
+    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
+    assert_eq!(
+        test_utils::to_vec1_round(&grad_t.flatten_all()?, 4)?,
+        // THIS IS WRONG AT THE MOMENT
+        [
+            1.7442, -10.1747, -9.9426, 0.0, 0.0, -1.7046, -21.2248, 30.8435, 0.0, 0.0, -18.713,
+            -1.0547, -7.8746, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 169.3047,
+            46.0812, 40.6937, 0.0, 0.0, -85.8156, 4.537, 53.2871, 0.0, 0.0, -59.632, -35.9725,
+            -7.1689, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 48.823, 8.9794,
+            42.3011, 0.0, 0.0, -58.9268, 32.907, -50.6863, 0.0, 0.0, -0.9706, -3.9175, -4.2594,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 72.8229, 25.8492, 8.9871,
+            0.0, 0.0, -136.2584, 40.1739, 88.9583, 0.0, 0.0, -53.465, -40.7102, -24.9406, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
+        ]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(&grad_w.flatten_all()?, 4)?,
+        [
+            -28.9232, -22.8833, -141.2296, 73.3462, 61.074, 47.8125, -20.0013, -73.7086, -41.8217,
+            -13.5919, 21.501, 28.7179, 28.5683, -46.8486, -90.1874, 143.6107, 16.6764, 7.4259,
+            18.8794, -90.8122, -20.2865, 54.7909, 82.6287, 22.943, 77.8084, -16.3928, -13.1977,
+            9.3442, -40.3869, -26.6153, 5.3344, -60.9081, 9.0869, -59.368, 7.081, 58.6391, 5.5476,
+            20.5152, 2.4985, -17.2466, -6.802, 22.2146, 30.1511, -7.5179, -37.4588, 5.6654,
+            22.5832, 9.0316, 47.0547, 17.6123, 37.3121, -98.1295, -14.6141, -4.7958, -6.3597,
+            44.6949, 23.3418, 8.3728, -13.52, 80.0522, -34.2403, -16.3648, -12.3139, 1.9195,
+            -33.6244, -14.102, -49.2305, -7.3853, 11.4995, -9.9826, 9.6588, 29.6042
+        ]
+    );
+    Ok(())
+}
+
 test_device!(conv1d, conv1d_cpu, conv1d_gpu);
 test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
 test_device!(conv2d, conv2d_cpu, conv2d_gpu);
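
A note on the `out_padding` computation in the new `Op::Conv2D` backward arm: the forward conv2d output size is `(i + 2 * padding - k) / stride + 1`, and the integer division is lossy when `stride > 1`, so the `conv_transpose2d` that propagates the gradient back needs an explicit `output_padding` to land on the exact input size. The same arithmetic explains the reordering of `out_h`/`out_w` in `conv.rs`: subtracting `2 * self.padding` before the additions can underflow for `usize`. A minimal standalone sketch (the helper names are ours, not part of the patch):

```rust
// Forward conv2d output size (dilation = 1). The integer division is lossy
// for stride > 1; output_padding compensates for that on the way back.
fn conv_out(i: usize, k: usize, padding: usize, stride: usize) -> usize {
    (i + 2 * padding - k) / stride + 1
}

// conv_transpose2d output size (dilation = 1), written like the patched
// ParamsConvTranspose2D::out_h: all additions first, then the subtraction,
// so the unsigned arithmetic cannot underflow.
fn conv_transpose_out(i: usize, k: usize, padding: usize, out_padding: usize, stride: usize) -> usize {
    (i - 1) * stride + (k - 1) + out_padding + 1 - 2 * padding
}

fn main() {
    for (i_h, k_h, padding, stride) in [(5, 3, 0, 1), (10, 3, 1, 2), (8, 4, 2, 3)] {
        let grad_h = conv_out(i_h, k_h, padding, stride);
        // Same computation as the new Op::Conv2D backward arm in backprop.rs:
        let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
        let out_padding = i_h - out_size;
        assert_eq!(conv_transpose_out(grad_h, k_h, padding, out_padding, stride), i_h);
    }
}
```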
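The expected `grad_t` values in `conv2d_grad` are explicitly flagged as wrong for now, so an independent sanity check is useful while the conv_transpose2d path settles. Below is a finite-difference sketch against the candle-core API used in the test; `sq_loss`, the random shapes, and `eps` are our illustrative choices, not part of the patch:

```rust
use candle_core::{Device, Result, Tensor, Var};

// Scalar loss matching the test above: sum of squared conv2d outputs.
fn sq_loss(t: &Tensor, w: &Tensor) -> Result<Tensor> {
    t.conv2d(w, 0, 1, 1)?.sqr()?.sum_all()
}

fn main() -> Result<()> {
    let dev = &Device::Cpu;
    let t = Var::randn(0f32, 1f32, (1, 2, 5, 5), dev)?;
    let w = Var::randn(0f32, 1f32, (3, 2, 3, 3), dev)?;
    let loss = sq_loss(t.as_tensor(), w.as_tensor())?;
    let grads = loss.backward()?;
    let grad_w = grads.get(&w).unwrap();

    // Perturb the kernel along a random direction d and compare the
    // finite-difference quotient with the directional derivative <grad_w, d>.
    let eps = 1e-3f64;
    let d = Tensor::randn(0f32, 1f32, (3, 2, 3, 3), dev)?;
    let w2 = w.as_tensor().add(&d.affine(eps, 0.)?)?;
    let loss2 = sq_loss(t.as_tensor(), &w2)?;
    let fd = loss2.sub(&loss)?.affine(1.0 / eps, 0.)?.to_scalar::<f32>()?;
    let analytic = grad_w.mul(&d)?.sum_all()?.to_scalar::<f32>()?;
    println!("finite-diff: {fd}, analytic: {analytic}");
    Ok(())
}
```

The two printed values should be close; exact agreement is limited by f32 rounding in the subtraction and the O(eps) curvature term of the quadratic loss. A large gap would point at either the `grad_arg` or the `grad_kernel` path in `backprop.rs`.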