add max_pool2d (#371)

Co-authored-by: 赵理山 <ls@zhaolishandeMacBook-Air.local>
This commit is contained in:
LeeeSe
2023-08-10 01:05:26 +08:00
committed by GitHub
parent 1892bd139c
commit a5c5a893aa
9 changed files with 115 additions and 0 deletions

View File

@ -674,6 +674,48 @@ impl Map1 for AvgPool2D {
}
}
struct MaxPool2D((usize, usize), (usize, usize));
impl Map1 for MaxPool2D {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
let (k_h, k_w) = self.0;
let (s_h, s_w) = self.1;
let (b_sz, c, h, w) = layout.shape().dims4()?;
let stride = layout.stride();
let (stride_h, stride_w) = (stride[2], stride[3]);
let h_out = (h - k_h) / s_h + 1;
let w_out = (w - k_w) / s_w + 1;
let src_index = layout.start_offset();
let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
for b_idx in 0..b_sz {
let dst = &mut dst[b_idx * c * h_out * w_out..];
let src_index = src_index + b_idx * stride[0];
for c_idx in 0..c {
let dst = &mut dst[c_idx * h_out * w_out..];
let src_index = src_index + c_idx * stride[1];
for h_idx in 0..h_out {
for w_idx in 0..w_out {
let mut largest =
src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
for m in 0..k_h {
for n in 0..k_w {
let m = s_h * h_idx + m;
let n = s_w * w_idx + n;
if largest < src[src_index + m * stride_h + n * stride_w] {
largest = src[src_index + m * stride_h + n * stride_w]
}
}
}
dst[h_idx * w_out + w_idx] = largest;
}
}
}
}
Ok(dst)
}
}
struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D {
@ -1664,6 +1706,15 @@ impl BackendStorage for CpuStorage {
AvgPool2D(kernel_size, stride).map(self, layout)
}
fn max_pool2d(
&self,
layout: &Layout,
kernel_size: (usize, usize),
stride: (usize, usize),
) -> Result<Self> {
MaxPool2D(kernel_size, stride).map(self, layout)
}
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
UpsampleNearest2D(h, w).map(self, layout)
}