Start adding the cublas based matmul.

laurent
2023-06-22 18:45:10 +01:00
parent 683730c21d
commit cc78900922
3 changed files with 89 additions and 4 deletions

View File

@@ -1,5 +1,6 @@
use crate::{CpuStorage, DType, Shape};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
use std::sync::Arc;
@@ -292,4 +293,77 @@ impl CudaStorage {
            }
        }
    }

    pub(crate) fn matmul_impl(
        &self,
        rhs: &Self,
        (b, m, n, k): (usize, usize, usize, usize),
        lhs_stride: &[usize],
        rhs_stride: &[usize],
    ) -> Result<Self> {
        use cudarc::cublas::sys::cublasOperation_t;
        let elem_count = b * m * n * k;
        let dev = &self.device;
        let slice = match (&self.slice, &rhs.slice) {
            (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
                let gemm = GemmConfig {
                    alpha: 1.,
                    beta: 1.,
                    m: m as i32,
                    n: n as i32,
                    k: k as i32,
                    lda: n as i32, // TODO
                    ldb: k as i32, // TODO
                    ldc: n as i32, // TODO
                    transa: cublasOperation_t::CUBLAS_OP_N,
                    transb: cublasOperation_t::CUBLAS_OP_T,
                };
                let cfg = StridedBatchedConfig {
                    batch_size: b as i32,
                    gemm,
                    stride_a: lhs_stride[0] as i64,
                    stride_b: rhs_stride[0] as i64,
                    stride_c: 42, // TODO,
                };
                let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
                unsafe {
                    self.device
                        .blas
                        .gemm_strided_batched(cfg, lhs, rhs, &mut out)
                }?;
                CudaStorageSlice::F32(out)
            }
            (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
                let gemm = GemmConfig {
                    alpha: 1.,
                    beta: 1.,
                    m: m as i32,
                    n: n as i32,
                    k: k as i32,
                    lda: n as i32, // TODO
                    ldb: k as i32, // TODO
                    ldc: n as i32, // TODO
                    transa: cublasOperation_t::CUBLAS_OP_N,
                    transb: cublasOperation_t::CUBLAS_OP_T,
                };
                let cfg = StridedBatchedConfig {
                    batch_size: b as i32,
                    gemm,
                    stride_a: lhs_stride[0] as i64,
                    stride_b: rhs_stride[0] as i64,
                    stride_c: 42, // TODO,
                };
                let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
                unsafe {
                    self.device
                        .blas
                        .gemm_strided_batched(cfg, lhs, rhs, &mut out)
                }?;
                CudaStorageSlice::F64(out)
            }
            _ => return Err(CudaError::InternalError("dtype mismatch in matmul op")),
        };
        let device = dev.clone();
        Ok(Self { slice, device })
    }
}
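
At this stage the lda/ldb/ldc and stride_c values are placeholders (hence the TODOs), beta is 1 even though the output buffer is freshly allocated, and elem_count over-allocates: a (b, m, k) by (b, k, n) matmul only needs b * m * n output elements. For reference, here is a minimal, self-contained sketch, not taken from this commit, of how a contiguous row-major batched matmul typically maps onto cuBLAS' column-major gemm_strided_batched through cudarc; the helper name and the f32-only signature are illustrative assumptions.

// Illustrative sketch only (not part of this commit). lhs is (b, m, k),
// rhs is (b, k, n), out is (b, m, n), all contiguous and row-major.
use cudarc::cublas::sys::cublasOperation_t;
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{CudaDevice, CudaSlice};
use std::sync::Arc;

fn batched_matmul_f32(
    dev: &Arc<CudaDevice>,
    blas: &CudaBlas,
    lhs: &CudaSlice<f32>,
    rhs: &CudaSlice<f32>,
    (b, m, n, k): (usize, usize, usize, usize),
) -> CudaSlice<f32> {
    // cuBLAS is column-major; for row-major data we compute C^T = B^T * A^T
    // by passing rhs as operand A, lhs as operand B, and swapping m and n.
    let gemm = GemmConfig {
        alpha: 1f32,
        beta: 0f32, // do not accumulate into the uninitialized output
        m: n as i32,
        n: m as i32,
        k: k as i32,
        lda: n as i32, // row width of rhs, i.e. its (k, n) row-major layout
        ldb: k as i32, // row width of lhs, i.e. its (m, k) row-major layout
        ldc: n as i32, // row width of out, i.e. its (m, n) row-major layout
        transa: cublasOperation_t::CUBLAS_OP_N,
        transb: cublasOperation_t::CUBLAS_OP_N,
    };
    let cfg = StridedBatchedConfig {
        gemm,
        batch_size: b as i32,
        stride_a: (k * n) as i64, // one rhs matrix per batch element
        stride_b: (m * k) as i64, // one lhs matrix per batch element
        stride_c: (m * n) as i64, // one output matrix per batch element
    };
    // Only b * m * n elements are needed for the result.
    let mut out = unsafe { dev.alloc::<f32>(b * m * n) }.expect("cuda alloc failed");
    unsafe { blas.gemm_strided_batched(cfg, rhs, lhs, &mut out) }.expect("cublas gemm failed");
    out
}

Under that mapping, the natural values for the placeholders would be stride_c = m * n and leading dimensions taken from the row widths of the inputs; handling arbitrary lhs_stride / rhs_stride (non-contiguous operands) is left open by this commit.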

View File

@@ -71,4 +71,14 @@ impl CudaStorage {
    ) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub(crate) fn matmul_impl(
        &self,
        _: &Self,
        _: (usize, usize, usize, usize),
        _: &[usize],
        _: &[usize],
    ) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }
}
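
This stub mirrors the CUDA backend's signature so the crate keeps compiling when CUDA support is switched off; every call simply reports Error::NotCompiledWithCudaSupport. The wiring that picks between the two backends is not part of this diff, but a typical Cargo-feature gate looks roughly like the following (module and feature names here are assumptions, not taken from the commit):

// Hypothetical sketch of feature-gated backend selection; not from this diff.
#[cfg(feature = "cuda")]
mod gpu {
    // The real cuBLAS-backed storage would live here.
    pub struct CudaStorage;
}

#[cfg(not(feature = "cuda"))]
mod gpu {
    // The dummy storage whose methods all return NotCompiledWithCudaSupport.
    pub struct CudaStorage;
}

// The rest of the crate imports a single name and never has to care
// which implementation is active.
pub use gpu::CudaStorage;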

View File

@@ -132,12 +132,13 @@ impl Storage {
        self.same_device(rhs, "matmul")?;
        self.same_dtype(rhs, "matmul")?;
        match (self, rhs) {
-           (Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => {
-               let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
+           (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+               let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
                Ok(Self::Cpu(storage))
            }
-           (Self::Cuda(_), Self::Cuda(_)) => {
-               todo!()
+           (Self::Cuda(lhs), Self::Cuda(rhs)) => {
+               let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
+               Ok(Self::Cuda(storage))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),