Add grads for interpolate1d (#1742)

* add backprop for interpolate1d

* fix clippy lint

* correct fix clippy lint
This commit is contained in:
Kirpal Grewal
2024-02-22 07:44:01 +00:00
committed by GitHub
parent 45d5322d62
commit 8013b50829
4 changed files with 51 additions and 6 deletions

View File

@ -1015,7 +1015,7 @@ impl Tensor {
/// tensor also has three dimensions, `(batch, channels, target_size)`.
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
let (n, c, _l) = self.dims3()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
let storage = self
.storage()
.upsample_nearest1d(self.layout(), target_size)?;