diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
index 012456ec..3ba30f94 100644
--- a/candle-nn/examples/cpu_benchmarks.rs
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -117,7 +117,11 @@ impl Benchmark for Conv2dIm2Col {
         let (h_out, w_out) = (h - h_k + 1, w - w_k + 1);
         let col = d.0.apply_op1_no_bwd(&Im2Col(h_k, w_k))?;
         let res = col.matmul(&d.1.flatten_from(1)?.t()?)?;
-        res.reshape((b, (), h_out, w_out))
+        let res = res
+            .reshape((b, h_out, w_out, ()))?
+            .permute((0, 3, 1, 2))?
+            .contiguous()?;
+        Ok(res)
     }
 
     const ITERS: usize = 5;