Add 1d upsampling. (#839)

* Add 1d upsampling.

* Add the interpolate functions.
This commit is contained in:
Laurent Mazare
2023-09-13 17:50:39 +02:00
committed by GitHub
parent 31ab2ddaeb
commit 9a465e1b26
8 changed files with 86 additions and 2 deletions

View File

@ -854,12 +854,30 @@ impl Tensor {
self.maximum(min)?.minimum(max)
}
/// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
/// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element.
///
/// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
/// tensor also has three dimensions, `(batch, channels, target_size)`.
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
let (n, c, _l) = self.dims3()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
let storage = self
.storage()
.upsample_nearest1d(self.layout(), target_size)?;
Ok(from_storage(storage, (n, c, target_size), op, false))
}
/// Alias for `interpolate1d`.
pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
self.interpolate1d(target_size)
}
/// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
/// nearest element.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
let storage = self
@ -868,6 +886,11 @@ impl Tensor {
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}
/// Alias for `interpolate2d`.
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
self.interpolate2d(target_h, target_w)
}
/// 2D average pooling over an input tensor with multiple channels.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned