diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 612359f4..7e4675f7 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -1289,6 +1289,15 @@ impl Map2 for MatMul { } else { Parallelism::None }; + let (b, m, n, k) = if b_skip == 0 && a_skip == m * k { + // a_skip and c_skip should be updated but step is always 0 so + // it wouldn't matter. + (1, b * m, n, k) + } else if a_skip == 0 && b_skip == n * k { + (1, m, b * n, k) + } else { + (b, m, n, k) + }; for step in 0..b { let lhs_p = &lhs[step * a_skip..]; let rhs_p = &rhs[step * b_skip..];