diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 307b56dc..c3f8aa3c 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -38,6 +38,9 @@ pub trait BackendStorage: Sized {
     ) -> Result<Self>;
 
     fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+    /// Nearest-neighbor 2D upsampling; the two `usize` arguments are the target height and
+    /// width, in that order.
+    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
 
     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
     fn scatter_add(
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 401a2c0e..a04ed9a0 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -672,6 +672,50 @@ impl Map1 for AvgPool2D {
     }
 }
 
+/// Nearest-neighbor 2D upsampling; the fields are the target `(dst_h, dst_w)`.
+struct UpsampleNearest2D(usize, usize);
+
+impl Map1 for UpsampleNearest2D {
+    /// Nearest-neighbor resize of the two trailing dimensions of a `(b_sz, c, src_h, src_w)`
+    /// input to `(dst_h, dst_w)`; returns a new buffer laid out as `(b_sz, c, dst_h, dst_w)`.
+    /// The layout shape must unpack via `dims4`, otherwise that error is propagated.
+    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+        // TODO: Specialized implementation for the case 2*h, 2*w?
+        let (dst_h, dst_w) = (self.0, self.1);
+        let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
+        let stride = layout.stride();
+        let (stride_h, stride_w) = (stride[2], stride[3]);
+        let src_index = layout.start_offset();
+        let scale_h = src_h as f64 / dst_h as f64;
+        let scale_w = src_w as f64 / dst_w as f64;
+        let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
+        // Source index for every *destination* row/column (clamped to the last valid source
+        // index), so each output element is written exactly once.
+        // NOTE(review): `src_h - 1` / `src_w - 1` assume a non-empty input; confirm upstream.
+        let src_h_idxs = (0..dst_h)
+            .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
+            .collect::<Vec<_>>();
+        let src_w_idxs = (0..dst_w)
+            .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
+            .collect::<Vec<_>>();
+        for b_idx in 0..b_sz {
+            let dst = &mut dst[b_idx * c * dst_h * dst_w..];
+            let src_index = src_index + b_idx * stride[0];
+            for c_idx in 0..c {
+                let dst = &mut dst[c_idx * dst_h * dst_w..];
+                let src_index = src_index + c_idx * stride[1];
+                for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
+                    for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
+                        let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
+                        dst[h_idx * dst_w + w_idx] = src[src_index]
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct Gather<'a, I: IntDType> {
     ids: &'a [I],
     ids_l: &'a Layout,
@@ -1577,6 +1621,11 @@ impl BackendStorage for CpuStorage {
         AvgPool2D(kernel_size, stride).map(self, layout)
     }
 
+    /// Nearest-neighbor upsampling of this storage to a target height `h` and width `w`.
+    fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
+        UpsampleNearest2D(h, w).map(self, layout)
+    }
+
     fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
         // TODO: Have some generic map for functions that apply on num_traits::Float elements.
         match self {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index e71ecfce..3c37373a 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1385,6 +1385,11 @@ impl BackendStorage for CudaStorage {
         todo!()
     }
 
+    /// Not implemented for CUDA yet: calling this panics through `todo!()`.
+    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
+        todo!()
+    }
+
     fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
         let device = self.device().clone();
         let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 2d5f955c..99cb7c4e 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -123,6 +123,10 @@ impl crate::backend::BackendStorage for CudaStorage {
     fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+
+    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
 }
 
 impl crate::backend::BackendDevice for CudaDevice {
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 47df689c..b4fa02e8 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -284,13 +284,18 @@ impl Storage {
         }
     }
 
-    pub(crate) fn upsample_nearest2d(
-        &self,
-        _layout: &Layout,
-        _h: usize,
-        _w: usize,
-    ) -> Result<Self> {
-        todo!()
+    /// Dispatches nearest-neighbor 2D upsampling to the backend that holds this storage.
+    pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
+        match self {
+            Self::Cpu(storage) => {
+                let storage = storage.upsample_nearest2d(layout, h, w)?;
+                Ok(Self::Cpu(storage))
+            }
+            Self::Cuda(storage) => {
+                let storage = storage.upsample_nearest2d(layout, h, w)?;
+                Ok(Self::Cuda(storage))
+            }
+        }
     }
 
     pub(crate) fn where_cond(