mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
add max_pool2d (#371)
Co-authored-by: 赵理山 <ls@zhaolishandeMacBook-Air.local>
This commit is contained in:
@ -872,6 +872,22 @@ impl Tensor {
|
||||
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
|
||||
}
|
||||
|
||||
pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
||||
let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
|
||||
arg,
|
||||
kernel_size,
|
||||
stride,
|
||||
});
|
||||
let storage = self
|
||||
.storage()
|
||||
.max_pool2d(self.layout(), kernel_size, stride)?;
|
||||
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
|
||||
}
|
||||
|
||||
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
|
Reference in New Issue
Block a user