Improve the handling of matmul with squeezed layouts. (#1998)

* Improve the handling of matmul with squeezed layouts.

* Fix for the cuda backend.

* Revert the temporary fix.
This commit is contained in:
Laurent Mazare
2024-04-02 23:17:05 +02:00
committed by GitHub
parent d17b2cdad9
commit 08c049def3
5 changed files with 151 additions and 139 deletions

View File

@@ -1174,6 +1174,8 @@ fn gemm_config<T>(
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[_, stride] if lhs_l.dims()[0] == 1 => stride,
[stride, _] if lhs_l.dims()[1] == 1 => stride,
[stride] => stride,
[] => m * k,
_ => Err(CudaError::MatMulNonContiguous {
@@ -1184,6 +1186,8 @@ fn gemm_config<T>(
};
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[_, stride] if rhs_l.dims()[0] == 1 => stride,
[stride, _] if rhs_l.dims()[1] == 1 => stride,
[stride] => stride,
[] => n * k,
_ => Err(CudaError::MatMulNonContiguous {