diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 67a08714..03a07434 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -57,6 +57,7 @@ pub trait BackendStorage: Sized { fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result; fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result; + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result; fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result; fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result; diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index d2099df7..b930a9f4 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -91,6 +91,7 @@ impl Tensor { } } Op::Reshape(node) + | Op::UpsampleNearest1D(node) | Op::UpsampleNearest2D(node) | Op::AvgPool2D { arg: node, .. } | Op::MaxPool2D { arg: node, .. } @@ -262,6 +263,9 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad_arg)?; } + Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported { + op: "upsample-nearest1d", + })?, Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported { op: "upsample-nearest2d", })?, diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 544ce32d..4e808b34 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -727,6 +727,36 @@ impl Map1 for MaxPool2D { } } +struct UpsampleNearest1D(usize); + +impl Map1 for UpsampleNearest1D { + fn f(&self, src: &[T], layout: &Layout) -> Result> { + // TODO: Specialized implementation for the case 2*sz? + let dst_sz = self.0; + let (b_sz, c, src_sz) = layout.shape().dims3()?; + let stride = layout.stride(); + let stride_sz = stride[2]; + let src_index = layout.start_offset(); + let scale_sz = src_sz as f64 / dst_sz as f64; + let mut dst = vec![T::zero(); b_sz * c * dst_sz]; + let src_idxs = (0..dst_sz) + .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize)) + .collect::>(); + for b_idx in 0..b_sz { + let dst = &mut dst[b_idx * c * dst_sz..]; + let src_index = src_index + b_idx * stride[0]; + for c_idx in 0..c { + let dst = &mut dst[c_idx * dst_sz..]; + let src_index = src_index + c_idx * stride[1]; + for (idx, src_idx) in src_idxs.iter().enumerate() { + dst[idx] = src[src_index + src_idx * stride_sz] + } + } + } + Ok(dst) + } +} + struct UpsampleNearest2D(usize, usize); impl Map1 for UpsampleNearest2D { @@ -2137,6 +2167,10 @@ impl BackendStorage for CpuStorage { MaxPool2D(kernel_size, stride).map(self, layout) } + fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result { + UpsampleNearest1D(sz).map(self, layout) + } + fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result { UpsampleNearest2D(h, w).map(self, layout) } diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 4ec41b87..00fd1d04 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1954,6 +1954,10 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } + fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result { + crate::bail!("upsample-nearest1d is not supported on cuda") + } + fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result { let device = self.device().clone(); let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?; diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 6c896653..5cc9c6d8 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -152,6 +152,10 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 9382b217..7940739c 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -116,6 +116,7 @@ pub enum Op { stride: (usize, usize), }, + UpsampleNearest1D(Tensor), UpsampleNearest2D(Tensor), Cat(Vec, usize), diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 8bd14ea9..9bd1fed6 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -369,6 +369,19 @@ impl Storage { } } + pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result { + match self { + Storage::Cpu(storage) => { + let storage = storage.upsample_nearest1d(layout, sz)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.upsample_nearest1d(layout, sz)?; + Ok(Self::Cuda(storage)) + } + } + } + pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result { match self { Storage::Cpu(storage) => { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 59a23c39..4388bf77 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -854,12 +854,30 @@ impl Tensor { self.maximum(min)?.minimum(max) } - /// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the + /// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element. + /// + /// The input tensor should have three dimensions, `(batch, channels, l)`, the returned + /// tensor also has three dimensions, `(batch, channels, target_size)`. + pub fn interpolate1d(&self, target_size: usize) -> Result { + let (n, c, _l) = self.dims3()?; + let op = BackpropOp::new1(self, Op::UpsampleNearest1D); + let storage = self + .storage() + .upsample_nearest1d(self.layout(), target_size)?; + Ok(from_storage(storage, (n, c, target_size), op, false)) + } + + /// Alias for `interpolate1d`. + pub fn upsample_nearest1d(&self, target_size: usize) -> Result { + self.interpolate1d(target_size) + } + + /// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the /// nearest element. /// /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`. - pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result { + pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result { let (n, c, _h, _w) = self.dims4()?; let op = BackpropOp::new1(self, Op::UpsampleNearest2D); let storage = self @@ -868,6 +886,11 @@ impl Tensor { Ok(from_storage(storage, (n, c, target_h, target_w), op, false)) } + /// Alias for `interpolate2d`. + pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result { + self.interpolate2d(target_h, target_w) + } + /// 2D average pooling over an input tensor with multiple channels. /// /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned