mirror of
https://github.com/huggingface/candle.git
Handle transposed matrices in cublas.
@@ -1,4 +1,5 @@
 // TODO: Use a proper distributed reduction rather than atomicAdd.
 // https://people.maths.ox.ac.uk/gilesm/cuda/prac4/reduction.pdf
+#include "cuda_utils.cuh"
 #include<stdint.h>
 
@@ -25,6 +25,13 @@ pub enum CudaError {
     #[error("internal error '{0}'")]
     InternalError(&'static str),
 
+    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
+    MatMulNonContiguous {
+        lhs_stride: Vec<usize>,
+        rhs_stride: Vec<usize>,
+        mnk: (usize, usize, usize),
+    },
+
     #[error("{msg}, expected: {expected:?}, got: {got:?}")]
     UnexpectedDType {
         msg: &'static str,
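The new variant carries the offending strides and problem sizes in its payload, so a failed matmul reports exactly which layout was rejected. A minimal standalone sketch of the message it renders (DemoError is a hypothetical stand-in using the same thiserror attribute, not code from this repo):

// Minimal reproduction of the new variant's Display output via thiserror.
use thiserror::Error;

#[derive(Debug, Error)]
enum DemoError {
    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Vec<usize>,
        rhs_stride: Vec<usize>,
        mnk: (usize, usize, usize),
    },
}

fn main() {
    let err = DemoError::MatMulNonContiguous {
        lhs_stride: vec![1, 2], // a transposed (2, 2) lhs
        rhs_stride: vec![3, 1], // a contiguous (2, 3) rhs
        mnk: (2, 3, 2),
    };
    // Prints the stride and shape details captured at the failure site.
    println!("{err}");
}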
@@ -197,12 +204,40 @@ fn gemm_config<T>(
     alpha: T,
     beta: T,
     (b, m, n, k): (usize, usize, usize, usize),
-    _lhs_stride: &[usize],
-    _rhs_stride: &[usize],
-) -> StridedBatchedConfig<T> {
-    // TODO: Handle lhs_stride and rhs_stride.
+    lhs_stride: &[usize],
+    rhs_stride: &[usize],
+) -> Result<StridedBatchedConfig<T>> {
     // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
     use cudarc::cublas::sys::cublasOperation_t;
+
+    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+    // The a tensor has dims batching, m, k
+    let transa = if lhs_m1 == 1 && lhs_m2 == k {
+        cublasOperation_t::CUBLAS_OP_N
+    } else if lhs_m1 == m && lhs_m2 == 1 {
+        cublasOperation_t::CUBLAS_OP_T
+    } else {
+        Err(CudaError::MatMulNonContiguous {
+            lhs_stride: lhs_stride.to_vec(),
+            rhs_stride: rhs_stride.to_vec(),
+            mnk: (m, n, k),
+        })?
+    };
+    // The b tensor has dims batching, k, n
+    let transb = if rhs_m1 == 1 && rhs_m2 == n {
+        cublasOperation_t::CUBLAS_OP_N
+    } else if rhs_m1 == k && rhs_m2 == 1 {
+        cublasOperation_t::CUBLAS_OP_T
+    } else {
+        Err(CudaError::MatMulNonContiguous {
+            lhs_stride: lhs_stride.to_vec(),
+            rhs_stride: rhs_stride.to_vec(),
+            mnk: (m, n, k),
+        })?
+    };
     // The setup below was copied from:
     // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
     let gemm = GemmConfig {
@@ -214,16 +249,16 @@ fn gemm_config<T>(
         lda: n as i32,
         ldb: k as i32,
         ldc: n as i32,
-        transa: cublasOperation_t::CUBLAS_OP_N,
-        transb: cublasOperation_t::CUBLAS_OP_N,
+        transa,
+        transb,
     };
-    StridedBatchedConfig {
+    Ok(StridedBatchedConfig {
         batch_size: b as i32,
         gemm,
         stride_a: (m * k) as i64,
         stride_b: (n * k) as i64,
         stride_c: (m * n) as i64,
-    }
+    })
 }
 
 impl CudaStorage {
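The rule the new checks implement: a row-major (r, c) matrix is contiguous when its last two strides are [c, 1], and a zero-copy transpose of a contiguous tensor shows up as strides [1, r]; the first case maps to CUBLAS_OP_N, the second to CUBLAS_OP_T, and anything else is rejected as MatMulNonContiguous. A standalone sketch of that detection logic (Op and op_for are illustrative names, not identifiers from the repo):

// Stride-based transpose detection, mirroring the gemm_config checks above.
#[derive(Debug, PartialEq)]
enum Op {
    N, // contiguous row-major: strides [c, 1]
    T, // stride-only transpose of a contiguous tensor: strides [1, r]
}

fn op_for(stride: &[usize], r: usize, c: usize) -> Option<Op> {
    let m1 = stride[stride.len() - 1];
    let m2 = stride[stride.len() - 2];
    if m1 == 1 && m2 == c {
        Some(Op::N)
    } else if m1 == r && m2 == 1 {
        Some(Op::T)
    } else {
        None // e.g. a strided slice; gemm_config errors out here
    }
}

fn main() {
    // A contiguous row-major (r, c) = (2, 4) matrix has strides [4, 1].
    assert_eq!(op_for(&[4, 1], 2, 4), Some(Op::N));
    // A (2, 4) view obtained by transposing a contiguous (4, 2) tensor has
    // strides [1, 2]: cublas can consume it directly with CUBLAS_OP_T.
    assert_eq!(op_for(&[1, 2], 2, 4), Some(Op::T));
    // Any other layout is rejected rather than silently miscomputed.
    assert_eq!(op_for(&[8, 2], 2, 4), None);
}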
@@ -580,7 +615,7 @@ impl CudaStorage {
         let dev = &self.device;
         let slice = match (&self.slice, &rhs.slice) {
             (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
-                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride);
+                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
                 let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
                 unsafe {
                     self.device
@@ -590,7 +625,7 @@ impl CudaStorage {
                 CudaStorageSlice::F32(out)
             }
             (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
-                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride);
+                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
                 let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
                 unsafe {
                     self.device
@@ -445,13 +445,6 @@ impl Tensor {
                 op: "matmul",
             });
         }
-        if self.device().is_cuda() && (!self.is_contiguous() || !rhs.is_contiguous()) {
-            // It looks like the cublas implementation of XgemmStridedBatched only supports
-            // non-standard strides on the batch dimension.
-            return Err(Error::RequiresContiguous {
-                op: "matmul-cublas",
-            });
-        }
 
         let m = a_dims[dim - 2];
         let k = a_dims[dim - 1];
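With the contiguity guard removed, a stride-only transpose can feed a CUDA matmul directly instead of erroring with RequiresContiguous. A hedged usage sketch, assuming the candle API of this vintage (Device::new_cuda, Tensor::new, Tensor::t, Tensor::matmul):

// Hedged sketch: a transposed view goes straight into a CUDA matmul.
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::new_cuda(0)?;
    let a = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &dev)?; // shape (3, 2)
    // `t()` only swaps dims and strides, so `a.t()` is not contiguous.
    // Before this commit the removed check above rejected it on CUDA;
    // now gemm_config maps the transposed strides to CUBLAS_OP_T instead.
    let c = a.t()?.matmul(&a)?; // (2, 3) x (3, 2) = (2, 2)
    println!("{:?}", c.shape());
    Ok(())
}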