Optimize the batched matmul for the cpu backend. (#2884)

This commit is contained in:
Laurent Mazare
2025-04-12 21:40:40 +02:00
committed by GitHub
parent 34505fdf3a
commit 15ed0b11ce

View File

@ -1289,6 +1289,15 @@ impl Map2 for MatMul {
} else {
Parallelism::None
};
let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
// a_skip and c_skip should be updated but step is always 0 so
// it wouldn't matter.
(1, b * m, n, k)
} else if a_skip == 0 && b_skip == n * k {
(1, m, b * n, k)
} else {
(b, m, n, k)
};
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];