diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index fc0c79a2..c152f31f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -114,7 +114,7 @@ impl Tensor {
                 | Op::Unary(_node, UnaryOp::Round) => nodes,
                 Op::Reshape(node)
                 | Op::UpsampleNearest1D(node)
-                | Op::UpsampleNearest2D(node)
+                | Op::UpsampleNearest2D { arg: node, .. }
                 | Op::AvgPool2D { arg: node, .. }
                 | Op::MaxPool2D { arg: node, .. }
                 | Op::Copy(node)
@@ -350,9 +350,27 @@ impl Tensor {
                 Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
                     op: "upsample-nearest1d",
                 })?,
-                Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
-                    op: "upsample-nearest2d",
-                })?,
+                Op::UpsampleNearest2D {
+                    arg,
+                    target_h,
+                    target_w,
+                } => {
+                    let (_n, c, h, w) = arg.dims4()?;
+                    if target_h % h != 0 || target_w % w != 0 {
+                        crate::bail!("backward not supported for non integer upscaling factors")
+                    }
+                    let scale_h = target_h / h;
+                    let scale_w = target_w / w;
+
+                    if scale_h != scale_w {
+                        crate::bail!("backward not supported for non uniform upscaling factors")
+                    };
+                    let kernel =
+                        Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
+                    let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = conv_sum;
+                }
                 Op::SliceScatter0(lhs, rhs, start_rhs) => {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index fbb20f6c..868673e7 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -132,7 +132,11 @@ pub enum Op {
     },
     UpsampleNearest1D(Tensor),
-    UpsampleNearest2D(Tensor),
+    UpsampleNearest2D {
+        arg: Tensor,
+        target_h: usize,
+        target_w: usize,
+    },
     Cat(Vec<Tensor>, usize),
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 87323a84..4d9b0837 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -994,7 +994,11 @@ impl Tensor {
     /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
     pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
         let (n, c, _h, _w) = self.dims4()?;
-        let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
+        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
+            arg,
+            target_h,
+            target_w,
+        });
         let storage = self
             .storage()
             .upsample_nearest2d(self.layout(), target_h, target_w)?;
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 791532f2..16e7a82f 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -270,6 +270,166 @@ fn unary_grad(device: &Device) -> Result<()> {
         [0.7358, 2.0000, 0.2707, 1.0000]
     );
+    // manually checked: see comments
+    let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
+    let y = x.interpolate2d(6, 6)?.reshape(36)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04., 05., 06.,
+            07., 08., 09., 10., 11., 12.,
+            13., 14., 15., 16., 17., 18.,
+            19., 20., 21., 22., 23., 24.,
+            25., 26., 27., 28., 29., 30.,
+            31., 32., 33., 34., 35., 36.,
+        ],
+        device,
+    )?;
+    // gradient should be
+    // row 1
+    // 1+2+7+8 = 18
+    // 3+4+9+10 = 26
+    // 5+6+11+12 = 34
+    // row 2
+    // 13+14+19+20 = 66
+    // 15+16+21+22 = 74
+    // 17+18+23+24 = 82
+    // row 3
+    // 25+26+31+32 = 114
+    // 27+28+33+34 = 122
+    // 29+30+35+36 = 130
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
+        [[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
+    let y = x.interpolate2d(6, 6)?.reshape(36)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04., 05., 06.,
+            07., 08., 09., 10., 11., 12.,
+            13., 14., 15., 16., 17., 18.,
+            19., 20., 21., 22., 23., 24.,
+            25., 26., 27., 28., 29., 30.,
+            31., 32., 33., 34., 35., 36.,
+        ],
+        device,
+    )?;
+    // gradient should be
+    // row 1
+    // 1+2+3+7+8+9+13+14+15 = 72
+    // 4+5+6+10+11+12+16+17+18 = 99
+    // row 2
+    // 19+20+21+25+26+27+31+32+33 = 234
+    // 22+23+24+28+29+30+34+35+36 = 261
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
+        [[72_f32, 99.], [234., 261.]]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;
+
+    let y = x.interpolate2d(4, 4)?.reshape(32)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04.,
+            05., 06., 07., 08.,
+            09., 10., 11., 12.,
+            13., 14., 15., 16.,
+            17., 18., 19., 20.,
+            21., 22., 23., 24.,
+            25., 26., 27., 28.,
+            29., 30., 31., 32.
+        ],
+        device,
+    )?;
+    // gradient should be
+    // m1r1
+    // 1+2+5+6=14
+    // 3+4+7+8=22
+    // m1r2
+    // 9+10+13+14=46
+    // 11+12+15+16=54
+    // m2r1
+    // 17+18+21+22=78
+    // 19+20+23+24=86
+    // m2r2
+    // 25+26+29+30=110
+    // 27+28+31+32=118
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
+        [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
+    );
+
+    // manually checked: see comments
+    let x = Var::new(
+        &[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
+        device,
+    )?;
+
+    let y = x.interpolate2d(4, 4)?.reshape(32)?;
+
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04.,
+            05., 06., 07., 08.,
+            09., 10., 11., 12.,
+            13., 14., 15., 16.,
+            17., 18., 19., 20.,
+            21., 22., 23., 24.,
+            25., 26., 27., 28.,
+            29., 30., 31., 32.
+        ],
+        device,
+    )?;
+    // gradient should be
+    // m1r1
+    // 1+2+5+6=14
+    // 3+4+7+8=22
+    // m1r2
+    // 9+10+13+14=46
+    // 11+12+15+16=54
+    // m2r1
+    // 17+18+21+22=78
+    // 19+20+23+24=86
+    // m2r2
+    // 25+26+29+30=110
+    // 27+28+31+32=118
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
+        [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
+    );
     Ok(())
 }
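
Note: the backward rule this patch adds is per-block summation. Each input pixel's gradient is the sum of the output gradients over the `scale_h x scale_w` block it was broadcast into, which the patch computes as a grouped `conv2d` with an all-ones kernel and stride equal to the scale factor. Below is a minimal end-to-end sketch of that behaviour through the public API, assuming a `candle-core` build that includes this patch; the sample values and the sum loss are illustrative only, not part of the test suite.

```rust
use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // A 1x1x2x2 input upsampled to 4x4: nearest-neighbor interpolation
    // copies each input pixel into a 2x2 block of the output.
    let x = Var::new(&[[[[1f32, 2.], [3., 4.]]]], &device)?;
    let y = x.interpolate2d(4, 4)?;
    // With a plain sum as the loss, every output gradient is 1, so each
    // input gradient is the block size: scale_h * scale_w = 4.
    let grads = y.sum_all()?.backward()?;
    let grad_x = grads.get(&x).expect("no grad for x");
    assert_eq!(grad_x.flatten_all()?.to_vec1::<f32>()?, [4f32, 4., 4., 4.]);
    Ok(())
}
```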