diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 1b107ecc..56fa1684 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -214,11 +214,11 @@ fn gemm_config( let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - // The a tensor has dims batching, m, k - let transa = if lhs_m1 == 1 && lhs_m2 == k { - cublasOperation_t::CUBLAS_OP_N - } else if rhs_m1 == m && rhs_m2 == 1 { - cublasOperation_t::CUBLAS_OP_T + // The a tensor has dims batching, k, n (rhs) + let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n { + (n as i32, cublasOperation_t::CUBLAS_OP_N) + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, cublasOperation_t::CUBLAS_OP_T) } else { Err(CudaError::MatMulNonContiguous { lhs_stride: lhs_stride.to_vec(), @@ -226,11 +226,11 @@ fn gemm_config( mnk: (m, n, k), })? }; - // The b tensor has dims batching, k, n - let transb = if rhs_m1 == 1 && rhs_m2 == n { - cublasOperation_t::CUBLAS_OP_N - } else if rhs_m1 == k && rhs_m2 == 1 { - cublasOperation_t::CUBLAS_OP_T + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k { + (k as i32, cublasOperation_t::CUBLAS_OP_N) + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, cublasOperation_t::CUBLAS_OP_T) } else { Err(CudaError::MatMulNonContiguous { lhs_stride: lhs_stride.to_vec(), @@ -246,8 +246,8 @@ fn gemm_config( m: n as i32, n: m as i32, k: k as i32, - lda: n as i32, - ldb: k as i32, + lda, + ldb, ldc: n as i32, transa, transb,