From 4f18180fc7564b9a02cab13c4acda0ac7b17a799 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 10 Sep 2023 16:59:46 +0100 Subject: [PATCH] Bugfix so that im2col produce the same results as conv2d. (#801) --- candle-nn/examples/cpu_benchmarks.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs index 012456ec..3ba30f94 100644 --- a/candle-nn/examples/cpu_benchmarks.rs +++ b/candle-nn/examples/cpu_benchmarks.rs @@ -117,7 +117,11 @@ impl Benchmark for Conv2dIm2Col { let (h_out, w_out) = (h - h_k + 1, w - w_k + 1); let col = d.0.apply_op1_no_bwd(&Im2Col(h_k, w_k))?; let res = col.matmul(&d.1.flatten_from(1)?.t()?)?; - res.reshape((b, (), h_out, w_out)) + let res = res + .reshape((b, h_out, w_out, ()))? + .permute((0, 3, 1, 2))? + .contiguous()?; + Ok(res) } const ITERS: usize = 5;