From 2231c717d5569015b30062877478983f6a120d04 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 22 Jun 2023 21:11:41 +0100 Subject: [PATCH] Fix the matmul example. --- src/cuda_backend.rs | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 7a22df4f..8704077c 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -154,21 +154,23 @@ fn gemm_config( alpha: T, beta: T, (b, m, n, k): (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + _lhs_stride: &[usize], + _rhs_stride: &[usize], ) -> StridedBatchedConfig { + // TODO: Handle lhs_stride and rhs_stride. // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm use cudarc::cublas::sys::cublasOperation_t; - println!("{:?} {:?} {:?}", lhs_stride, rhs_stride, (b, m, n, k)); + // The setup below was copied from: + // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531 let gemm = GemmConfig { alpha, beta, - m: m as i32, - n: n as i32, + m: n as i32, + n: m as i32, k: k as i32, - lda: m as i32, + lda: n as i32, ldb: k as i32, - ldc: m as i32, + ldc: n as i32, transa: cublasOperation_t::CUBLAS_OP_N, transb: cublasOperation_t::CUBLAS_OP_N, }; @@ -341,7 +343,7 @@ impl CudaStorage { unsafe { self.device .blas - .gemm_strided_batched(cfg, lhs, rhs, &mut out) + .gemm_strided_batched(cfg, rhs, lhs, &mut out) }?; CudaStorageSlice::F32(out) } @@ -351,7 +353,7 @@ impl CudaStorage { unsafe { self.device .blas - .gemm_strided_batched(cfg, lhs, rhs, &mut out) + .gemm_strided_batched(cfg, rhs, lhs, &mut out) }?; CudaStorageSlice::F64(out) }