mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add 1d upsampling. (#839)
* Add 1d upsampling. * Add the interpolate functions.
This commit is contained in:
@ -854,12 +854,30 @@ impl Tensor {
|
||||
self.maximum(min)?.minimum(max)
|
||||
}
|
||||
|
||||
/// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
|
||||
/// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element.
|
||||
///
|
||||
/// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
|
||||
/// tensor also has three dimensions, `(batch, channels, target_size)`.
|
||||
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
|
||||
let (n, c, _l) = self.dims3()?;
|
||||
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
|
||||
let storage = self
|
||||
.storage()
|
||||
.upsample_nearest1d(self.layout(), target_size)?;
|
||||
Ok(from_storage(storage, (n, c, target_size), op, false))
|
||||
}
|
||||
|
||||
/// Alias for `interpolate1d`.
|
||||
pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
|
||||
self.interpolate1d(target_size)
|
||||
}
|
||||
|
||||
/// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
|
||||
/// nearest element.
|
||||
///
|
||||
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
|
||||
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
|
||||
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
|
||||
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
|
||||
let (n, c, _h, _w) = self.dims4()?;
|
||||
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
|
||||
let storage = self
|
||||
@ -868,6 +886,11 @@ impl Tensor {
|
||||
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
|
||||
}
|
||||
|
||||
/// Alias for `interpolate2d`.
|
||||
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
|
||||
self.interpolate2d(target_h, target_w)
|
||||
}
|
||||
|
||||
/// 2D average pooling over an input tensor with multiple channels.
|
||||
///
|
||||
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
|
||||
|
Reference in New Issue
Block a user