diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 35619015..2a1db58a 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -113,7 +113,7 @@ impl Tensor {
                 | Op::Unary(_node, UnaryOp::Floor)
                 | Op::Unary(_node, UnaryOp::Round) => nodes,
                 Op::Reshape(node)
-                | Op::UpsampleNearest1D(node)
+                | Op::UpsampleNearest1D { arg: node, .. }
                 | Op::UpsampleNearest2D { arg: node, .. }
                 | Op::AvgPool2D { arg: node, .. }
                 | Op::MaxPool2D { arg: node, .. }
@@ -348,9 +348,18 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad_arg)?;
                 }
-                Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
-                    op: "upsample-nearest1d",
-                })?,
+                Op::UpsampleNearest1D { arg, target_size } => {
+                    let (_n, c, size) = arg.dims3()?;
+                    if target_size % size != 0 {
+                        crate::bail!("backward not supported for non integer upscaling factors")
+                    }
+                    let scale = target_size / size;
+
+                    let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
+                    let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = conv_sum;
+                }
                 Op::UpsampleNearest2D {
                     arg,
                     target_h,
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index d920485c..022b4fc3 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -132,7 +132,10 @@ pub enum Op {
         stride: (usize, usize),
     },
 
-    UpsampleNearest1D(Tensor),
+    UpsampleNearest1D {
+        arg: Tensor,
+        target_size: usize,
+    },
     UpsampleNearest2D {
         arg: Tensor,
         target_h: usize,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index a1aa9338..0e2c3e8f 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1015,7 +1015,7 @@ impl Tensor {
     /// tensor also has three dimensions, `(batch, channels, target_size)`.
     pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
         let (n, c, _l) = self.dims3()?;
-        let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
+        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
         let storage = self
             .storage()
             .upsample_nearest1d(self.layout(), target_size)?;
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 76987635..4fbb21a7 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -283,6 +283,39 @@ fn unary_grad(device: &Device) -> Result<()> {
         [1.0881, 0.9277, 1.0527, 0.5747],
     );
 
+    let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
+    let y = x.interpolate1d(12)?.reshape(36)?;
+
+    println!("y: {}", y.unsqueeze(1)?);
+    #[rustfmt::skip]
+    let z = Tensor::new(
+        &[
+            1_f32, 02., 03., 04.,
+            05., 06., 07., 08.,
+            09., 10., 11., 12.,
+            13., 14., 15., 16.,
+            17., 18., 19., 20.,
+            21., 22., 23., 24.,
+            25., 26., 27., 28.,
+            29., 30., 31., 32.,
+            33., 34., 35., 36.,
+        ],
+        device,
+    )?;
+
+    let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
+
+    let grads = loss.backward()?;
+
+    let grad_x = grads.get(&x).context("no grad for x")?;
+
+    println!("grad: {grad_x}");
+
+    assert_eq!(
+        test_utils::to_vec3_round(grad_x, 4)?,
+        [[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
+    );
+
     // manually checked: see comments
     let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
     let y = x.interpolate2d(6, 6)?.reshape(36)?;
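Reviewer note (not part of the diff above): the new backward arm relies on the fact that nearest-neighbor 1D upsampling by an integer factor copies each input element `scale = target_size / size` times, so the gradient of each input element is the sum of the `scale` output gradients it fanned out to. The channel-grouped `conv1d` with an all-ones kernel of width `scale` and stride `scale` computes exactly that blockwise sum. Below is a minimal standalone sketch checking this equivalence against a plain reshape-and-sum; it assumes `candle-core` as a dependency, and the shapes are chosen to mirror the new test rather than taken from the PR:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Shapes mirroring the new test: (n, c, size) upsampled to (n, c, target_size).
    let (n, c, size, target_size) = (1usize, 3usize, 3usize, 12usize);
    let scale = target_size / size; // 4; the backward pass bails out unless this divides evenly.

    // Stand-in for `grad`, the gradient arriving at the upsampled output.
    let grad = Tensor::arange(0f32, (n * c * target_size) as f32, &dev)?
        .reshape((n, c, target_size))?;

    // The PR's formulation: one all-ones kernel per channel, grouped by channel,
    // with stride `scale`, so each output position sums one block of `scale` values.
    let kernel = Tensor::ones((c, 1, scale), DType::F32, &dev)?;
    let via_conv = grad.conv1d(&kernel, 0, scale, 1, c)?;

    // The same reduction written directly: split the last axis into
    // (size, scale) blocks and sum each block away.
    let via_sum = grad.reshape((n, c, size, scale))?.sum(3)?;

    assert_eq!(via_conv.to_vec3::<f32>()?, via_sum.to_vec3::<f32>()?);
    Ok(())
}
```

The expected values in the new `grad_tests.rs` assertion follow the same rule: with `dloss/dy = z = 1..=36` and `scale = 4`, each entry of `grad_x` sums four consecutive entries of `z`, e.g. `1 + 2 + 3 + 4 = 10` and `5 + 6 + 7 + 8 = 26`, giving `[[[10., 26., 42.], [58., 74., 90.], [106., 122., 138.]]]`.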