add max_pool2d (#371)

Co-authored-by: 赵理山 <ls@zhaolishandeMacBook-Air.local>
This commit is contained in:
LeeeSe
2023-08-10 01:05:26 +08:00
committed by GitHub
parent 1892bd139c
commit a5c5a893aa
9 changed files with 115 additions and 0 deletions

View File

@ -872,6 +872,22 @@ impl Tensor {
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
let (n, c, h, w) = self.dims4()?;
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
arg,
kernel_size,
stride,
});
let storage = self
.storage()
.max_pool2d(self.layout(), kernel_size, stride)?;
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
///
/// # Arguments