Bugfix so that im2col produce the same results as conv2d. (#801)

This commit is contained in:
Laurent Mazare
2023-09-10 16:59:46 +01:00
committed by GitHub
parent 559944146f
commit 4f18180fc7

View File

@ -117,7 +117,11 @@ impl Benchmark for Conv2dIm2Col {
let (h_out, w_out) = (h - h_k + 1, w - w_k + 1);
let col = d.0.apply_op1_no_bwd(&Im2Col(h_k, w_k))?;
let res = col.matmul(&d.1.flatten_from(1)?.t()?)?;
res.reshape((b, (), h_out, w_out))
let res = res
.reshape((b, h_out, w_out, ()))?
.permute((0, 3, 1, 2))?
.contiguous()?;
Ok(res)
}
const ITERS: usize = 5;