// Patch overview: adds a 2-D max-pooling operation (`max_pool2d`) to candle-core.
// These `//` lines are commentary placed ahead of the first `diff --git` header; they are
// not part of any hunk.
//
// Per-file changes:
//   backend.rs            `BackendStorage` gains `max_pool2d(&self, &Layout, kernel, stride)`,
//                         declared right after `avg_pool2d` with the same argument list.
//   backprop.rs           `Op::MaxPool2D { arg: node, .. }` joins the match arm that already
//                         lists `Op::AvgPool2D`; the second match returns
//                         `Error::BackwardNotSupported { op: "max-pool2d" }`, so no gradient
//                         is implemented for this op.
//   cpu_backend.rs        New `MaxPool2D(kernel, stride)` with a `Map1` impl, plus the
//                         `CpuStorage::max_pool2d` method that calls it. The kernel:
//                           - reads the input as 4-D (b_sz, c, h, w) via `dims4`;
//                           - sizes each output plane as
//                             h_out = (h - k_h) / s_h + 1, w_out = (w - k_w) / s_w + 1;
//                           - addresses the source through the layout's start offset and its
//                             four strides, and writes a densely packed vector of
//                             b_sz * c * h_out * w_out elements;
//                           - seeds each window's maximum with its top-left element and
//                             replaces it whenever `largest < candidate`.
//                         Only kernel size and stride are accepted; there is no padding or
//                         dilation argument.
//   cuda_backend.rs       `max_pool2d` is a `todo!()` stub, so calling it panics.
//   dummy_cuda_backend.rs `max_pool2d` returns `Err(Error::NotCompiledWithCudaSupport)`.
//   op.rs                 New variant `Op::MaxPool2D { arg, kernel_size, stride }`.
//   storage.rs            `Storage::max_pool2d` forwards to the Cpu or Cuda storage and wraps
//                         the result back into the same `Storage` variant.
//   tensor.rs             Public `Tensor::max_pool2d(kernel_size, stride)`: destructures the
//                         shape with `dims4`, computes (n, c, h_out, w_out) with the same
//                         formula as the CPU kernel, records `Op::MaxPool2D`, and calls the
//                         storage method.
//   tests/pool_tests.rs   One test: a 1x1x4x4 input pooled with a (2, 2) kernel and (2, 2)
//                         stride must equal [[2, 3], [5, 1]].
//
// Review notes:
//   NOTE(review): `h - k_h` / `w - k_w` (cpu_backend.rs) and `h - kernel_size.0` /
//   `w - kernel_size.1` (tensor.rs) are usize subtractions with no guard visible in this
//   patch. A kernel larger than the input underflows (panic in debug builds, wrap-around in
//   release builds), and a stride of 0 divides by zero. Confirm whether callers validate
//   these before reaching this code; otherwise return an error from `Tensor::max_pool2d`.
//   NOTE(review): the maximum is chosen with `<` only. For float dtypes a NaN that is not
//   the first element of a window is never selected, while a NaN in the first position is
//   kept. Confirm this matches the intended reference behavior.
//   NOTE(review): this copy of the patch appears to have lost its line breaks and its
//   generic parameter lists (e.g. `-> Result> {`, `Vec = vec![`, `to_vec2::()`), so the
//   `@@` hunk headers no longer line up with the text. It presumably cannot be applied
//   as-is; regenerate it from the repository before use.
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index a8e5ac52..4c31ca6f 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -46,6 +46,7 @@ pub trait BackendStorage: Sized { ) -> Result; fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result; + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result; fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result; fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result; diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 0eab508e..2a60fe30 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -88,6 +88,7 @@ impl Tensor { Op::Reshape(node) | Op::UpsampleNearest2D(node) | Op::AvgPool2D { arg: node, .. } + | Op::MaxPool2D { arg: node, .. } | Op::Copy(node) | Op::Broadcast(node) | Op::Cmp(node, _) @@ -172,6 +173,7 @@ impl Tensor { Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?, Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?, Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?, + Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?, Op::UpsampleNearest2D { .. 
} => Err(Error::BackwardNotSupported { op: "upsample-nearest2d", })?, diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 1f94a9bc..54f3f65b 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -674,6 +674,48 @@ impl Map1 for AvgPool2D { } } +struct MaxPool2D((usize, usize), (usize, usize)); + +impl Map1 for MaxPool2D { + fn f(&self, src: &[T], layout: &Layout) -> Result> { + // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html + let (k_h, k_w) = self.0; + let (s_h, s_w) = self.1; + let (b_sz, c, h, w) = layout.shape().dims4()?; + let stride = layout.stride(); + let (stride_h, stride_w) = (stride[2], stride[3]); + let h_out = (h - k_h) / s_h + 1; + let w_out = (w - k_w) / s_w + 1; + let src_index = layout.start_offset(); + let mut dst = vec![T::zero(); b_sz * c * h_out * w_out]; + for b_idx in 0..b_sz { + let dst = &mut dst[b_idx * c * h_out * w_out..]; + let src_index = src_index + b_idx * stride[0]; + for c_idx in 0..c { + let dst = &mut dst[c_idx * h_out * w_out..]; + let src_index = src_index + c_idx * stride[1]; + for h_idx in 0..h_out { + for w_idx in 0..w_out { + let mut largest = + src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w]; + for m in 0..k_h { + for n in 0..k_w { + let m = s_h * h_idx + m; + let n = s_w * w_idx + n; + if largest < src[src_index + m * stride_h + n * stride_w] { + largest = src[src_index + m * stride_h + n * stride_w] + } + } + } + dst[h_idx * w_out + w_idx] = largest; + } + } + } + } + Ok(dst) + } +} + struct UpsampleNearest2D(usize, usize); impl Map1 for UpsampleNearest2D { @@ -1664,6 +1706,15 @@ impl BackendStorage for CpuStorage { AvgPool2D(kernel_size, stride).map(self, layout) } + fn max_pool2d( + &self, + layout: &Layout, + kernel_size: (usize, usize), + stride: (usize, usize), + ) -> Result { + MaxPool2D(kernel_size, stride).map(self, layout) + } + fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result { 
UpsampleNearest2D(h, w).map(self, layout) } diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 727ea073..e51cc05d 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1395,6 +1395,10 @@ impl BackendStorage for CudaStorage { todo!() } + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + todo!() + } + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result { todo!() } diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index ae4dd09f..870a87cd 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -134,6 +134,10 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index aea8b733..f99d8adc 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -93,6 +93,13 @@ pub enum Op { kernel_size: (usize, usize), stride: (usize, usize), }, + + MaxPool2D { + arg: Tensor, + kernel_size: (usize, usize), + stride: (usize, usize), + }, + UpsampleNearest2D(Tensor), Cat(Vec, usize), diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 3ed38e6a..791b65dd 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -311,6 +311,24 @@ impl Storage { } } + pub(crate) fn max_pool2d( + &self, + layout: &Layout, + kernel_size: (usize, usize), + stride: (usize, usize), + ) -> Result { + match self { + Storage::Cpu(storage) => { + let storage = storage.max_pool2d(layout, kernel_size, stride)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.max_pool2d(layout, kernel_size, 
stride)?; + Ok(Self::Cuda(storage)) + } + } + } + pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result { match self { Storage::Cpu(storage) => { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index adba7376..c94c0390 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -872,6 +872,22 @@ impl Tensor { Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) } + pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result { + let (n, c, h, w) = self.dims4()?; + // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d + let h_out = (h - kernel_size.0) / stride.0 + 1; + let w_out = (w - kernel_size.1) / stride.1 + 1; + let op = BackpropOp::new1(self, |arg| Op::MaxPool2D { + arg, + kernel_size, + stride, + }); + let storage = self + .storage() + .max_pool2d(self.layout(), kernel_size, stride)?; + Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) + } + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. /// /// # Arguments diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs index 7335e378..c8ddef97 100644 --- a/candle-core/tests/pool_tests.rs +++ b/candle-core/tests/pool_tests.rs @@ -14,6 +14,18 @@ fn avg_pool2d() -> anyhow::Result<()> { Ok(()) } +#[test] +fn max_pool2d() -> anyhow::Result<()> { + let data: Vec = vec![ + 1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1., + ]; + let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?; + + let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?; + assert_eq!(pool.to_vec2::()?, [[2f32, 3.], [5., 1.]]); + Ok(()) +} + /* This test corresponds to the following PyTorch script. import torch torch.manual_seed(4242)