Add 1d upsampling. (#839)

* Add 1d upsampling.

* Add the interpolate functions.
This commit is contained in:
Laurent Mazare
2023-09-13 17:50:39 +02:00
committed by GitHub
parent 31ab2ddaeb
commit 9a465e1b26
8 changed files with 86 additions and 2 deletions

View File

@ -91,6 +91,7 @@ impl Tensor {
}
}
Op::Reshape(node)
| Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
@ -262,6 +263,9 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest1d",
})?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,