Handle transposed matrices in cublas.

laurent
2023-06-26 17:49:29 +01:00
parent 3761f02aa8
commit 1ad5baecc5
3 changed files with 46 additions and 17 deletions


@@ -1,4 +1,5 @@
// TODO: Use a proper distributed reduction rather than atomicAdd.
// https://people.maths.ox.ac.uk/gilesm/cuda/prac4/reduction.pdf
#include "cuda_utils.cuh"
#include <stdint.h>


@@ -25,6 +25,13 @@ pub enum CudaError {
    #[error("internal error '{0}'")]
    InternalError(&'static str),
    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Vec<usize>,
        rhs_stride: Vec<usize>,
        mnk: (usize, usize, usize),
    },
    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
    UnexpectedDType {
        msg: &'static str,
@@ -197,12 +204,40 @@ fn gemm_config<T>(
    alpha: T,
    beta: T,
    (b, m, n, k): (usize, usize, usize, usize),
    _lhs_stride: &[usize],
    _rhs_stride: &[usize],
) -> StridedBatchedConfig<T> {
    // TODO: Handle lhs_stride and rhs_stride.
    lhs_stride: &[usize],
    rhs_stride: &[usize],
) -> Result<StridedBatchedConfig<T>> {
    // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
    use cudarc::cublas::sys::cublasOperation_t;
    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
    // The a tensor has dims batching, m, k
    let transa = if lhs_m1 == 1 && lhs_m2 == k {
        cublasOperation_t::CUBLAS_OP_N
    } else if lhs_m1 == m && lhs_m2 == 1 {
        cublasOperation_t::CUBLAS_OP_T
    } else {
        Err(CudaError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
        })?
    };
    // The b tensor has dims batching, k, n
    let transb = if rhs_m1 == 1 && rhs_m2 == n {
        cublasOperation_t::CUBLAS_OP_N
    } else if rhs_m1 == k && rhs_m2 == 1 {
        cublasOperation_t::CUBLAS_OP_T
    } else {
        Err(CudaError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
        })?
    };
    // The setup below was copied from:
    // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
    let gemm = GemmConfig {
@@ -214,16 +249,16 @@ fn gemm_config<T>(
        lda: n as i32,
        ldb: k as i32,
        ldc: n as i32,
        transa: cublasOperation_t::CUBLAS_OP_N,
        transb: cublasOperation_t::CUBLAS_OP_N,
        transa,
        transb,
    };
    StridedBatchedConfig {
    Ok(StridedBatchedConfig {
        batch_size: b as i32,
        gemm,
        stride_a: (m * k) as i64,
        stride_b: (n * k) as i64,
        stride_c: (m * n) as i64,
    }
    })
}
impl CudaStorage {
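
As a stand-alone sketch (not part of the diff), the stride check that gemm_config performs above boils down to the following, assuming row-major layouts; needs_transpose, d_major, and d_minor are illustrative names rather than crate API:

// For a logical (d_major, d_minor) matrix, a contiguous row-major layout has
// strides [d_minor, 1] -> CUBLAS_OP_N, while the transpose of a contiguous
// (d_minor, d_major) buffer has strides [1, d_major] -> CUBLAS_OP_T.
// Anything else corresponds to the MatMulNonContiguous error.
fn needs_transpose(stride: &[usize], d_major: usize, d_minor: usize) -> Option<bool> {
    let m1 = stride[stride.len() - 1];
    let m2 = stride[stride.len() - 2];
    if m1 == 1 && m2 == d_minor {
        Some(false) // CUBLAS_OP_N
    } else if m1 == d_major && m2 == 1 {
        Some(true) // CUBLAS_OP_T
    } else {
        None // rejected as non-contiguous
    }
}

fn main() {
    // lhs of shape (m, k) = (2, 3), stored contiguously: strides [3, 1].
    assert_eq!(needs_transpose(&[3, 1], 2, 3), Some(false));
    // Same logical shape as the transpose of a contiguous (3, 2) buffer: strides [1, 2].
    assert_eq!(needs_transpose(&[1, 2], 2, 3), Some(true));
    // A (2, 3) window into a wider (2, 6) buffer: strides [6, 1] -> not handled.
    assert_eq!(needs_transpose(&[6, 1], 2, 3), None);
}
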
@@ -580,7 +615,7 @@ impl CudaStorage {
        let dev = &self.device;
        let slice = match (&self.slice, &rhs.slice) {
            (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride);
                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
                let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
                unsafe {
                    self.device
@@ -590,7 +625,7 @@ impl CudaStorage {
                CudaStorageSlice::F32(out)
            }
            (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride);
                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
                let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
                unsafe {
                    self.device


@@ -445,13 +445,6 @@ impl Tensor {
                op: "matmul",
            });
        }
        if self.device().is_cuda() && (!self.is_contiguous() || !rhs.is_contiguous()) {
            // It looks like the cublas implementation of XgemmStridedBatched only supports
            // non-standard strides on the batch dimension.
            return Err(Error::RequiresContiguous {
                op: "matmul-cublas",
            });
        }
        let m = a_dims[dim - 2];
        let k = a_dims[dim - 1];
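
With the guard above removed, a matmul whose operand is a simple transposed view should now take the cuBLAS path instead of returning RequiresContiguous. A hypothetical usage sketch only: Device::new_cuda, Tensor::zeros, transpose, matmul, and the candle::Result alias are assumed to exist with these signatures at this point in the crate's history and may differ.

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::new_cuda(0)?;
    let lhs = Tensor::zeros((2, 3), DType::F32, &dev)?;
    // A contiguous (4, 3) tensor transposed into a (3, 4) view: its strides become
    // [1, 3], which gemm_config now maps to CUBLAS_OP_T instead of erroring out.
    let rhs = Tensor::zeros((4, 3), DType::F32, &dev)?.transpose(0, 1)?;
    let out = lhs.matmul(&rhs)?; // (2, 3) @ (3, 4) -> (2, 4)
    println!("{:?}", out.shape());
    Ok(())
}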