More flexible matmul contiguity checks. (#1949)

* More flexible matmul contiguity checks.

* Also relax the checks on the metal side.
This commit is contained in:
Laurent Mazare
2024-03-27 10:59:05 +01:00
committed by GitHub
parent 75b6d4b0da
commit a9abde5f93
4 changed files with 51 additions and 8 deletions

View File

@ -1451,9 +1451,12 @@ pub fn call_gemm(
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
// lhs has shape b, m, k
// We also allow for the case where a stride does not have the expected contiguous value but the
// dimensions it steps over hold a single element (e.g. b * m == 1), so its value is irrelevant.
let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
false
} else if lhs_m1 == m && lhs_m2 == 1 {
} else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
@ -1462,9 +1465,10 @@ pub fn call_gemm(
mnk: (m, n, k),
})?;
};
let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
// rhs has shape b, k, n
let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
false
} else if rhs_m1 == k && rhs_m2 == 1 {
} else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {