Add some conv2d tests. (#347)

* Add some conv2d tests.

* Add a simpler conv2d test.

* More conv2d testing + bugfix.

* Add a todo.
This commit is contained in:
Laurent Mazare
2023-08-08 20:02:42 +02:00
committed by GitHub
parent 13ce68ff9b
commit 1e6dbeac01
3 changed files with 121 additions and 4 deletions

View File

@ -1037,7 +1037,7 @@ struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
const OP: &'static str = "conv2d";
fn f<T: 'static + num_traits::NumAssign + Copy>(
fn f<T: 'static + num_traits::NumAssign + Copy + std::fmt::Display>(
&self,
inp: &[T],
inp_l: &Layout,
@ -1063,11 +1063,13 @@ impl<'a> Map2 for Conv2D<'a> {
let dst_idx = dst_idx + dst_w;
let mut d = T::zero();
for offset_h in 0..p.k_h {
let src_h_plus = p.stride * dst_h + offset_h;
// TODO: Handle the case where padding is larger than p.k_h / 2.
let src_h_plus = p.stride * dst_h + offset_h + p.k_h / 2 - p.padding;
if p.k_h / 2 <= src_h_plus && src_h_plus < p.k_h / 2 + p.i_h {
let src_h = src_h_plus - p.k_h / 2;
for offset_w in 0..p.k_w {
let src_w_plus = p.stride * dst_w + offset_w;
let src_w_plus =
p.stride * dst_w + offset_w + p.k_w / 2 - p.padding;
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
if p.k_w / 2 <= src_w_plus && src_w_plus < p.k_w / 2 + p.i_w {
let src_w = src_w_plus - p.k_w / 2;