Cublas fixes.

This commit is contained in:
laurent
2023-06-26 17:59:27 +01:00
parent 1ad5baecc5
commit 46789c403c

View File

@ -214,11 +214,11 @@ fn gemm_config<T>(
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// The a tensor has dims batching, m, k
let transa = if lhs_m1 == 1 && lhs_m2 == k {
cublasOperation_t::CUBLAS_OP_N
} else if rhs_m1 == m && rhs_m2 == 1 {
cublasOperation_t::CUBLAS_OP_T
// cuBLAS operand `a` is the rhs tensor, with dims (batch, k, n).
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
(n as i32, cublasOperation_t::CUBLAS_OP_N)
} else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
@ -226,11 +226,11 @@ fn gemm_config<T>(
mnk: (m, n, k),
})?
};
// The b tensor has dims batching, k, n
let transb = if rhs_m1 == 1 && rhs_m2 == n {
cublasOperation_t::CUBLAS_OP_N
} else if rhs_m1 == k && rhs_m2 == 1 {
cublasOperation_t::CUBLAS_OP_T
// cuBLAS operand `b` is the lhs tensor, with dims (batch, m, k).
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
(k as i32, cublasOperation_t::CUBLAS_OP_N)
} else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
@ -246,8 +246,8 @@ fn gemm_config<T>(
m: n as i32,
n: m as i32,
k: k as i32,
lda: n as i32,
ldb: k as i32,
lda,
ldb,
ldc: n as i32,
transa,
transb,