Add grads for interpolate1d (#1742)

* add backprop for interpolate1d

* fix clippy lint

* correct fix clippy lint
This commit is contained in:
Kirpal Grewal
2024-02-22 07:44:01 +00:00
committed by GitHub
parent 45d5322d62
commit 8013b50829
4 changed files with 51 additions and 6 deletions

View File

@ -113,7 +113,7 @@ impl Tensor {
| Op::Unary(_node, UnaryOp::Floor)
| Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D(node)
| Op::UpsampleNearest1D { arg: node, .. }
| Op::UpsampleNearest2D { arg: node, .. }
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
@ -348,9 +348,18 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest1d",
})?,
Op::UpsampleNearest1D { arg, target_size } => {
let (_n, c, size) = arg.dims3()?;
if target_size % size != 0 {
crate::bail!("backward not supported for non integer upscaling factors")
}
let scale = target_size / size;
let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = conv_sum;
}
Op::UpsampleNearest2D {
arg,
target_h,

View File

@ -132,7 +132,10 @@ pub enum Op {
stride: (usize, usize),
},
UpsampleNearest1D(Tensor),
UpsampleNearest1D {
arg: Tensor,
target_size: usize,
},
UpsampleNearest2D {
arg: Tensor,
target_h: usize,

View File

@ -1015,7 +1015,7 @@ impl Tensor {
/// tensor also has three dimensions, `(batch, channels, target_size)`.
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
let (n, c, _l) = self.dims3()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
let storage = self
.storage()
.upsample_nearest1d(self.layout(), target_size)?;

View File

@ -283,6 +283,39 @@ fn unary_grad(device: &Device) -> Result<()> {
[1.0881, 0.9277, 1.0527, 0.5747],
);
let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
let y = x.interpolate1d(12)?.reshape(36)?;
println!("y: {}", y.unsqueeze(1)?);
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04.,
05., 06., 07., 08.,
09., 10., 11., 12.,
13., 14., 15., 16.,
17., 18., 19., 20.,
21., 22., 23., 24.,
25., 26., 27., 28.,
29., 30., 31., 32.,
33., 34., 35., 36.,
],
device,
)?;
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
println!("grad: {grad_x}");
assert_eq!(
test_utils::to_vec3_round(grad_x, 4)?,
[[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
);
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?;