From 0671b8c3699e434b08c8af15700d7b1a541fe560 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 22 Jun 2023 19:24:02 +0100 Subject: [PATCH] Improve the gemm config. --- src/cuda_backend.rs | 71 ++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 39 deletions(-) diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 433e6a93..806d6c26 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -150,6 +150,36 @@ pub struct CudaStorage { device: CudaDevice, } +fn gemm_config( + alpha: T, + beta: T, + (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + rhs_stride: &[usize], +) -> StridedBatchedConfig { + // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm + use cudarc::cublas::sys::cublasOperation_t; + let gemm = GemmConfig { + alpha, + beta, + m: m as i32, + n: n as i32, + k: k as i32, + lda: lhs_stride[lhs_stride.len() - 2] as i32, + ldb: rhs_stride[rhs_stride.len() - 2] as i32, + ldc: m as i32, + transa: cublasOperation_t::CUBLAS_OP_N, + transb: cublasOperation_t::CUBLAS_OP_N, + }; + StridedBatchedConfig { + batch_size: b as i32, + gemm, + stride_a: lhs_stride[0] as i64, + stride_b: rhs_stride[0] as i64, + stride_c: (m * n * k) as i64, + } +} + impl CudaStorage { pub fn try_clone(&self) -> Result { let slice = match &self.slice { @@ -301,30 +331,11 @@ impl CudaStorage { lhs_stride: &[usize], rhs_stride: &[usize], ) -> Result { - use cudarc::cublas::sys::cublasOperation_t; let elem_count = b * m * n * k; let dev = &self.device; let slice = match (&self.slice, &rhs.slice) { (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { - let gemm = GemmConfig { - alpha: 1., - beta: 1., - m: m as i32, - n: n as i32, - k: k as i32, - lda: n as i32, // TODO - ldb: k as i32, // TODO - ldc: n as i32, // TODO - transa: cublasOperation_t::CUBLAS_OP_N, - transb: cublasOperation_t::CUBLAS_OP_T, - }; - let cfg = StridedBatchedConfig { - batch_size: b as i32, - gemm, - stride_a: lhs_stride[0] as i64, - stride_b: rhs_stride[0] as i64, - stride_c: 42, // TODO, - }; + let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride); let mut out = unsafe { dev.alloc::(elem_count) }?; unsafe { self.device @@ -334,25 +345,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { - let gemm = GemmConfig { - alpha: 1., - beta: 1., - m: m as i32, - n: n as i32, - k: k as i32, - lda: n as i32, // TODO - ldb: k as i32, // TODO - ldc: n as i32, // TODO - transa: cublasOperation_t::CUBLAS_OP_N, - transb: cublasOperation_t::CUBLAS_OP_T, - }; - let cfg = StridedBatchedConfig { - batch_size: b as i32, - gemm, - stride_a: lhs_stride[0] as i64, - stride_b: rhs_stride[0] as i64, - stride_c: 42, // TODO, - }; + let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride); let mut out = unsafe { dev.alloc::(elem_count) }?; unsafe { self.device