From 1ad5baecc5bba71a80d50f35a079743461e1fef4 Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 26 Jun 2023 17:49:29 +0100
Subject: [PATCH] Handle transposed matrices in cublas.

---
 kernels/src/reduce.cu |  1 +
 src/cuda_backend.rs   | 55 +++++++++++++++++++++++++++++++++++++++++++++----------
 src/tensor.rs         |  7 -------
 3 files changed, 46 insertions(+), 17 deletions(-)

diff --git a/kernels/src/reduce.cu b/kernels/src/reduce.cu
index 1d6ee436..d12d6b22 100644
--- a/kernels/src/reduce.cu
+++ b/kernels/src/reduce.cu
@@ -1,4 +1,5 @@
 // TODO: Use a proper distributed reduction rather than atomicAdd.
+// https://people.maths.ox.ac.uk/gilesm/cuda/prac4/reduction.pdf
 #include "cuda_utils.cuh"
 #include <stdint.h>
 
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 2c96cc6b..1b107ecc 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -25,6 +25,13 @@ pub enum CudaError {
     #[error("internal error '{0}'")]
     InternalError(&'static str),
 
+    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
+    MatMulNonContiguous {
+        lhs_stride: Vec<usize>,
+        rhs_stride: Vec<usize>,
+        mnk: (usize, usize, usize),
+    },
+
     #[error("{msg}, expected: {expected:?}, got: {got:?}")]
     UnexpectedDType {
         msg: &'static str,
@@ -197,12 +204,40 @@ fn gemm_config<T>(
     alpha: T,
     beta: T,
     (b, m, n, k): (usize, usize, usize, usize),
-    _lhs_stride: &[usize],
-    _rhs_stride: &[usize],
-) -> StridedBatchedConfig<T> {
-    // TODO: Handle lhs_stride and rhs_stride.
+    lhs_stride: &[usize],
+    rhs_stride: &[usize],
+) -> Result<StridedBatchedConfig<T>> {
     // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
     use cudarc::cublas::sys::cublasOperation_t;
+
+    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+    // The a operand is the rhs, with dims batching, k, n (see the m/n swap below).
+    let transa = if rhs_m1 == 1 && rhs_m2 == n {
+        cublasOperation_t::CUBLAS_OP_N
+    } else if rhs_m1 == k && rhs_m2 == 1 {
+        cublasOperation_t::CUBLAS_OP_T
+    } else {
+        Err(CudaError::MatMulNonContiguous {
+            lhs_stride: lhs_stride.to_vec(),
+            rhs_stride: rhs_stride.to_vec(),
+            mnk: (m, n, k),
+        })?
+    };
+    // The b operand is the lhs, with dims batching, m, k.
+    let transb = if lhs_m1 == 1 && lhs_m2 == k {
+        cublasOperation_t::CUBLAS_OP_N
+    } else if lhs_m1 == m && lhs_m2 == 1 {
+        cublasOperation_t::CUBLAS_OP_T
+    } else {
+        Err(CudaError::MatMulNonContiguous {
+            lhs_stride: lhs_stride.to_vec(),
+            rhs_stride: rhs_stride.to_vec(),
+            mnk: (m, n, k),
+        })?
+    };
     // The setup below was copied from:
     // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
     let gemm = GemmConfig {
         alpha,
         beta,
         m: n as i32,
         n: m as i32,
         k: k as i32,
         lda: n as i32,
         ldb: k as i32,
         ldc: n as i32,
-        transa: cublasOperation_t::CUBLAS_OP_N,
-        transb: cublasOperation_t::CUBLAS_OP_N,
+        transa,
+        transb,
     };
-    StridedBatchedConfig {
+    Ok(StridedBatchedConfig {
         batch_size: b as i32,
         gemm,
         stride_a: (m * k) as i64,
         stride_b: (n * k) as i64,
         stride_c: (m * n) as i64,
-    }
+    })
 }
 
 impl CudaStorage {
@@ -580,7 +615,7 @@ impl CudaStorage {
         let dev = &self.device;
         let slice = match (&self.slice, &rhs.slice) {
             (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
-                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride);
+                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
                 let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
                 unsafe {
                     self.device
@@ -590,7 +625,7 @@ impl CudaStorage {
                 CudaStorageSlice::F32(out)
             }
             (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
-                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride);
+                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
                 let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
                 unsafe {
                     self.device
diff --git a/src/tensor.rs b/src/tensor.rs
index 6a47ef4c..d177795f 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -445,13 +445,6 @@ impl Tensor {
                 op: "matmul",
             });
         }
-        if self.device().is_cuda() && (!self.is_contiguous() || !rhs.is_contiguous()) {
-            // It looks like the cublas implementation of XgemmStridedBatched only supports
-            // non-standard strides on the batch dimension.
-            return Err(Error::RequiresContiguous {
-                op: "matmul-cublas",
-            });
-        }
 
         let m = a_dims[dim - 2];
         let k = a_dims[dim - 1];
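
Note (editorial, not part of the patch): the m/n swap in GemmConfig comes from cuBLAS being column-major. A row-major (m, n) buffer read as column-major is exactly its transpose, so the row-major product C = lhs * rhs is obtained by asking cuBLAS for C^T = rhs^T * lhs^T: the rhs becomes the "a" operand with leading dimension n and the lhs becomes the "b" operand with leading dimension k, which is also why transa is derived from the rhs strides and transb from the lhs strides. The standalone Rust sketch below checks the underlying identity numerically; matmul and transpose are helpers written for this note, not functions from the codebase.

    // Verify (A * B)^T == B^T * A^T on small row-major buffers.
    fn matmul(a: &[f64], b: &[f64], m: usize, k: usize, n: usize) -> Vec<f64> {
        // Plain triple loop over row-major data: c[i][j] += a[i][p] * b[p][j].
        let mut c = vec![0.0; m * n];
        for i in 0..m {
            for j in 0..n {
                for p in 0..k {
                    c[i * n + j] += a[i * k + p] * b[p * n + j];
                }
            }
        }
        c
    }

    fn transpose(x: &[f64], rows: usize, cols: usize) -> Vec<f64> {
        // Row-major (rows, cols) in, row-major (cols, rows) out.
        let mut t = vec![0.0; rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                t[j * rows + i] = x[i * cols + j];
            }
        }
        t
    }

    fn main() {
        let (m, k, n) = (2, 3, 4);
        let a: Vec<f64> = (0..m * k).map(|v| v as f64).collect();
        let b: Vec<f64> = (0..k * n).map(|v| v as f64 * 0.5).collect();
        let lhs = transpose(&matmul(&a, &b, m, k, n), m, n); // (A * B)^T
        let rhs = matmul(&transpose(&b, k, n), &transpose(&a, m, k), n, k, m); // B^T * A^T
        assert_eq!(lhs, rhs);
        println!("(A * B)^T == B^T * A^T holds");
    }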
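
Note (editorial, not part of the patch): transa/transb are derived purely from the two innermost strides, so each operand may be either row-major contiguous or the transpose of a row-major contiguous matrix; any other layout (e.g. a sliced view) still fails with MatMulNonContiguous. A minimal sketch of the same check, runnable on the CPU without a GPU, is below; the names Op and lhs_op are illustrative and do not exist in the codebase.

    // Mirrors the patch's stride check for the lhs operand (dims m x k).
    #[derive(Debug, PartialEq)]
    enum Op {
        N, // no transpose needed
        T, // pass as transposed
    }

    fn lhs_op(lhs_stride: &[usize], m: usize, k: usize) -> Option<Op> {
        let m1 = lhs_stride[lhs_stride.len() - 1]; // innermost stride
        let m2 = lhs_stride[lhs_stride.len() - 2];
        if m1 == 1 && m2 == k {
            Some(Op::N) // row-major contiguous (m, k)
        } else if m1 == m && m2 == 1 {
            Some(Op::T) // transpose of a row-major contiguous (k, m)
        } else {
            None // would surface as MatMulNonContiguous
        }
    }

    fn main() {
        let (m, k) = (2, 3);
        assert_eq!(lhs_op(&[3, 1], m, k), Some(Op::N)); // contiguous (2, 3)
        assert_eq!(lhs_op(&[1, 2], m, k), Some(Op::T)); // transpose of a contiguous (3, 2)
        assert_eq!(lhs_op(&[6, 2], m, k), None); // strided slice: rejected
        println!("stride checks behave as expected");
    }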