From cc78900922a8185742192f985c927d5a877ef86b Mon Sep 17 00:00:00 2001
From: laurent
Date: Thu, 22 Jun 2023 18:45:10 +0100
Subject: [PATCH] Start adding the cublas based matmul.

---
 src/cuda_backend.rs       | 74 +++++++++++++++++++++++++++++++++++++++
 src/dummy_cuda_backend.rs | 10 ++++++
 src/storage.rs            |  9 ++---
 3 files changed, 89 insertions(+), 4 deletions(-)

diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 36a68731..433e6a93 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -1,5 +1,6 @@
 use crate::{CpuStorage, DType, Shape};
 use candle_kernels as kernels;
+use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
 use std::sync::Arc;
 
@@ -292,4 +293,77 @@ impl CudaStorage {
             }
         }
     }
+
+    pub(crate) fn matmul_impl(
+        &self,
+        rhs: &Self,
+        (b, m, n, k): (usize, usize, usize, usize),
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<Self> {
+        use cudarc::cublas::sys::cublasOperation_t;
+        let elem_count = b * m * n * k;
+        let dev = &self.device;
+        let slice = match (&self.slice, &rhs.slice) {
+            (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
+                let gemm = GemmConfig {
+                    alpha: 1.,
+                    beta: 1.,
+                    m: m as i32,
+                    n: n as i32,
+                    k: k as i32,
+                    lda: n as i32, // TODO
+                    ldb: k as i32, // TODO
+                    ldc: n as i32, // TODO
+                    transa: cublasOperation_t::CUBLAS_OP_N,
+                    transb: cublasOperation_t::CUBLAS_OP_T,
+                };
+                let cfg = StridedBatchedConfig {
+                    batch_size: b as i32,
+                    gemm,
+                    stride_a: lhs_stride[0] as i64,
+                    stride_b: rhs_stride[0] as i64,
+                    stride_c: 42, // TODO,
+                };
+                let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
+                unsafe {
+                    self.device
+                        .blas
+                        .gemm_strided_batched(cfg, lhs, rhs, &mut out)
+                }?;
+                CudaStorageSlice::F32(out)
+            }
+            (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
+                let gemm = GemmConfig {
+                    alpha: 1.,
+                    beta: 1.,
+                    m: m as i32,
+                    n: n as i32,
+                    k: k as i32,
+                    lda: n as i32, // TODO
+                    ldb: k as i32, // TODO
+                    ldc: n as i32, // TODO
+                    transa: cublasOperation_t::CUBLAS_OP_N,
+                    transb: cublasOperation_t::CUBLAS_OP_T,
+                };
+                let cfg = StridedBatchedConfig {
+                    batch_size: b as i32,
+                    gemm,
+                    stride_a: lhs_stride[0] as i64,
+                    stride_b: rhs_stride[0] as i64,
+                    stride_c: 42, // TODO,
+                };
+                let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
+                unsafe {
+                    self.device
+                        .blas
+                        .gemm_strided_batched(cfg, lhs, rhs, &mut out)
+                }?;
+                CudaStorageSlice::F64(out)
+            }
+            _ => return Err(CudaError::InternalError("dtype mismatch in matmul op")),
+        };
+        let device = dev.clone();
+        Ok(Self { slice, device })
+    }
 }
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
index 4bc59f61..f8669494 100644
--- a/src/dummy_cuda_backend.rs
+++ b/src/dummy_cuda_backend.rs
@@ -71,4 +71,14 @@ impl CudaStorage {
     ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+
+    pub(crate) fn matmul_impl(
+        &self,
+        _: &Self,
+        _: (usize, usize, usize, usize),
+        _: &[usize],
+        _: &[usize],
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
 }
diff --git a/src/storage.rs b/src/storage.rs
index 4bc24149..9f8cd2d5 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -132,12 +132,13 @@ impl Storage {
         self.same_device(rhs, "matmul")?;
         self.same_dtype(rhs, "matmul")?;
         match (self, rhs) {
-            (Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => {
-                let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
+            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+                let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
                 Ok(Self::Cpu(storage))
             }
-            (Self::Cuda(_), Self::Cuda(_)) => {
-                todo!()
+            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
+                let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
+                Ok(Self::Cuda(storage))
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),