mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 11:37:11 +00:00)
Handle transposed matrices in cublas.

@@ -1,4 +1,5 @@
 // TODO: Use a proper distributed reduction rather than atomicAdd.
+// https://people.maths.ox.ac.uk/gilesm/cuda/prac4/reduction.pdf
 #include "cuda_utils.cuh"
 #include<stdint.h>
 
@@ -25,6 +25,13 @@ pub enum CudaError {
     #[error("internal error '{0}'")]
     InternalError(&'static str),
 
+    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
+    MatMulNonContiguous {
+        lhs_stride: Vec<usize>,
+        rhs_stride: Vec<usize>,
+        mnk: (usize, usize, usize),
+    },
+
     #[error("{msg}, expected: {expected:?}, got: {got:?}")]
     UnexpectedDType {
         msg: &'static str,
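
For reference, a minimal standalone sketch (not the candle source; the enum and main below are illustrative) of how a thiserror variant shaped like the new MatMulNonContiguous renders its message:

    // Illustrative only: a thiserror enum with the same shape as the new variant.
    use thiserror::Error;

    #[derive(Debug, Error)]
    enum DemoError {
        #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
        MatMulNonContiguous {
            lhs_stride: Vec<usize>,
            rhs_stride: Vec<usize>,
            mnk: (usize, usize, usize),
        },
    }

    fn main() {
        let err = DemoError::MatMulNonContiguous {
            lhs_stride: vec![6, 1, 2], // strides of a transposed (2, 3) view with a batch dim
            rhs_stride: vec![12, 4, 1],
            mnk: (2, 4, 3),
        };
        // Prints the message defined by the #[error(...)] attribute, strides included.
        println!("{err}");
    }

The stride and mnk fields end up in the formatted string, which is what the gemm_config checks below rely on when they reject unsupported layouts.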
@@ -197,12 +204,40 @@ fn gemm_config<T>(
     alpha: T,
     beta: T,
     (b, m, n, k): (usize, usize, usize, usize),
-    _lhs_stride: &[usize],
-    _rhs_stride: &[usize],
-) -> StridedBatchedConfig<T> {
-    // TODO: Handle lhs_stride and rhs_stride.
+    lhs_stride: &[usize],
+    rhs_stride: &[usize],
+) -> Result<StridedBatchedConfig<T>> {
     // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
     use cudarc::cublas::sys::cublasOperation_t;
+
+    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+    // The a tensor has dims batching, m, k
+    let transa = if lhs_m1 == 1 && lhs_m2 == k {
+        cublasOperation_t::CUBLAS_OP_N
+    } else if rhs_m1 == m && rhs_m2 == 1 {
+        cublasOperation_t::CUBLAS_OP_T
+    } else {
+        Err(CudaError::MatMulNonContiguous {
+            lhs_stride: lhs_stride.to_vec(),
+            rhs_stride: rhs_stride.to_vec(),
+            mnk: (m, n, k),
+        })?
+    };
+    // The b tensor has dims batching, k, n
+    let transb = if rhs_m1 == 1 && rhs_m2 == n {
+        cublasOperation_t::CUBLAS_OP_N
+    } else if rhs_m1 == k && rhs_m2 == 1 {
+        cublasOperation_t::CUBLAS_OP_T
+    } else {
+        Err(CudaError::MatMulNonContiguous {
+            lhs_stride: lhs_stride.to_vec(),
+            rhs_stride: rhs_stride.to_vec(),
+            mnk: (m, n, k),
+        })?
+    };
     // The setup below was copied from:
     // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
     let gemm = GemmConfig {
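
The core of the new gemm_config logic is reading the last two strides of each operand: for a row-major matrix of shape (rows, cols), a contiguous layout has strides ending in (cols, 1), while a transposed view of a contiguous buffer has strides ending in (1, rows). A standalone sketch of that check, with a plain enum standing in for cublasOperation_t (names here are illustrative, not candle's):

    // Decide how a row-major matrix with the given trailing strides should be
    // passed to a GEMM that can transpose its inputs.
    #[derive(Debug, PartialEq)]
    enum Op {
        N, // CUBLAS_OP_N: use the matrix as stored
        T, // CUBLAS_OP_T: treat the stored data as transposed
    }

    fn op_for(strides: &[usize], rows: usize, cols: usize) -> Option<Op> {
        let m1 = strides[strides.len() - 1]; // innermost (column) stride
        let m2 = strides[strides.len() - 2]; // row stride
        if m1 == 1 && m2 == cols {
            Some(Op::N)
        } else if m1 == rows && m2 == 1 {
            Some(Op::T)
        } else {
            None // neither layout: the caller reports MatMulNonContiguous
        }
    }

    fn main() {
        // Logical shape (rows, cols) = (2, 3), with a leading batch stride of 6.
        assert_eq!(op_for(&[6, 3, 1], 2, 3), Some(Op::N)); // contiguous
        assert_eq!(op_for(&[6, 1, 2], 2, 3), Some(Op::T)); // transpose of a contiguous (3, 2) buffer
        assert_eq!(op_for(&[6, 2, 2], 2, 3), None); // arbitrary strides: rejected
    }

Anything that is neither plain nor transposed falls through to the error case, which is exactly what the new MatMulNonContiguous variant reports.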
@@ -214,16 +249,16 @@ fn gemm_config<T>(
         lda: n as i32,
         ldb: k as i32,
         ldc: n as i32,
-        transa: cublasOperation_t::CUBLAS_OP_N,
-        transb: cublasOperation_t::CUBLAS_OP_N,
+        transa,
+        transb,
     };
-    StridedBatchedConfig {
+    Ok(StridedBatchedConfig {
         batch_size: b as i32,
         gemm,
         stride_a: (m * k) as i64,
         stride_b: (n * k) as i64,
         stride_c: (m * n) as i64,
-    }
+    })
 }
 
 impl CudaStorage {
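
The unchanged leading dimensions lda = n, ldb = k, ldc = n come from the usual trick of driving a column-major GEMM with row-major buffers: the row-major product lhs * rhs is obtained by handing the column-major routine the two operands in swapped order. A CPU-only sketch of that identity (illustrative, no cublas involved):

    // Naive column-major matrix multiply: c (m x n) = a (m x k) * b (k x n).
    fn gemm_col_major(
        m: usize, n: usize, k: usize,
        a: &[f32], lda: usize,
        b: &[f32], ldb: usize,
        c: &mut [f32], ldc: usize,
    ) {
        for j in 0..n {
            for i in 0..m {
                let mut acc = 0.0;
                for p in 0..k {
                    acc += a[i + p * lda] * b[p + j * ldb];
                }
                c[i + j * ldc] = acc;
            }
        }
    }

    fn main() {
        let (m, n, k) = (2, 4, 3);
        // Row-major lhs (m x k) and rhs (k x n).
        let lhs: Vec<f32> = (0..m * k).map(|x| x as f32).collect();
        let rhs: Vec<f32> = (0..k * n).map(|x| x as f32).collect();
        let mut out = vec![0.0f32; m * n];
        // Swap the operands and pretend the buffers are column-major:
        // "a" is rhs with lda = n, "b" is lhs with ldb = k, "c" is out with ldc = n.
        gemm_col_major(n, m, k, &rhs, n, &lhs, k, &mut out, n);
        // out now holds lhs * rhs in row-major order:
        // row 0 is [20, 23, 26, 29], row 1 is [56, 68, 80, 92].
        println!("{out:?}");
    }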
@@ -580,7 +615,7 @@ impl CudaStorage {
         let dev = &self.device;
         let slice = match (&self.slice, &rhs.slice) {
             (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
-                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride);
+                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
                 let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
                 unsafe {
                     self.device
@@ -590,7 +625,7 @@ impl CudaStorage {
                 CudaStorageSlice::F32(out)
             }
             (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
-                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride);
+                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
                 let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
                 unsafe {
                     self.device
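
Both call sites above only change by a trailing `?`, because gemm_config now returns Result. A tiny standalone illustration of the two patterns involved, `Err(e)?` in value position inside gemm_config and `?` propagation at the call site (types here are placeholders, not candle's):

    // Placeholder error type, not candle's CudaError.
    #[derive(Debug)]
    struct ConfigError;

    fn pick_flag(stride: usize, expected: usize) -> Result<char, ConfigError> {
        let flag = if stride == expected {
            'N'
        } else if stride == 1 {
            'T'
        } else {
            // `Err(...)?` returns early even though an expression value is expected here.
            Err(ConfigError)?
        };
        Ok(flag)
    }

    fn caller() -> Result<(), ConfigError> {
        // The trailing `?` mirrors `gemm_config(...)?` in the diff: propagate or continue.
        let flag = pick_flag(1, 4)?;
        println!("transpose flag: {flag}");
        Ok(())
    }

    fn main() {
        caller().unwrap();
    }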
@@ -445,13 +445,6 @@ impl Tensor {
                 op: "matmul",
             });
         }
-        if self.device().is_cuda() && (!self.is_contiguous() || !rhs.is_contiguous()) {
-            // It looks like the cublas implementation of XgemmStridedBatched only supports
-            // non-standard strides on the batch dimension.
-            return Err(Error::RequiresContiguous {
-                op: "matmul-cublas",
-            });
-        }
 
         let m = a_dims[dim - 2];
         let k = a_dims[dim - 1];
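
With the guard removed, a transposed operand can reach the cublas path directly. A hedged usage sketch, assuming the current candle API names (Device::new_cuda, Tensor::arange, reshape, t, matmul), which may differ from the crate as of this commit:

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::new_cuda(0)?;
        let lhs = Tensor::arange(0f32, 6f32, &dev)?.reshape((2, 3))?;
        // Build the rhs as the transpose of a contiguous (4, 3) tensor, so its
        // strides are non-contiguous but still a plain transposed layout.
        let rhs = Tensor::arange(0f32, 12f32, &dev)?.reshape((4, 3))?.t()?;
        // (2, 3) x (3, 4) -> (2, 4); previously this returned RequiresContiguous
        // on CUDA, now gemm_config picks CUBLAS_OP_T for the transposed operand.
        let out = lhs.matmul(&rhs)?;
        println!("{out}");
        Ok(())
    }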