mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Start adding the cuBLAS-based matmul.
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
use crate::{CpuStorage, DType, Shape};
|
||||
use candle_kernels as kernels;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -292,4 +293,77 @@ impl CudaStorage {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
use cudarc::cublas::sys::cublasOperation_t;
|
||||
let elem_count = b * m * n * k;
|
||||
let dev = &self.device;
|
||||
let slice = match (&self.slice, &rhs.slice) {
|
||||
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
||||
let gemm = GemmConfig {
|
||||
alpha: 1.,
|
||||
beta: 1.,
|
||||
m: m as i32,
|
||||
n: n as i32,
|
||||
k: k as i32,
|
||||
lda: n as i32, // TODO
|
||||
ldb: k as i32, // TODO
|
||||
ldc: n as i32, // TODO
|
||||
transa: cublasOperation_t::CUBLAS_OP_N,
|
||||
transb: cublasOperation_t::CUBLAS_OP_T,
|
||||
};
|
||||
let cfg = StridedBatchedConfig {
|
||||
batch_size: b as i32,
|
||||
gemm,
|
||||
stride_a: lhs_stride[0] as i64,
|
||||
stride_b: rhs_stride[0] as i64,
|
||||
stride_c: 42, // TODO,
|
||||
};
|
||||
let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||
unsafe {
|
||||
self.device
|
||||
.blas
|
||||
.gemm_strided_batched(cfg, lhs, rhs, &mut out)
|
||||
}?;
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
|
||||
let gemm = GemmConfig {
|
||||
alpha: 1.,
|
||||
beta: 1.,
|
||||
m: m as i32,
|
||||
n: n as i32,
|
||||
k: k as i32,
|
||||
lda: n as i32, // TODO
|
||||
ldb: k as i32, // TODO
|
||||
ldc: n as i32, // TODO
|
||||
transa: cublasOperation_t::CUBLAS_OP_N,
|
||||
transb: cublasOperation_t::CUBLAS_OP_T,
|
||||
};
|
||||
let cfg = StridedBatchedConfig {
|
||||
batch_size: b as i32,
|
||||
gemm,
|
||||
stride_a: lhs_stride[0] as i64,
|
||||
stride_b: rhs_stride[0] as i64,
|
||||
stride_c: 42, // TODO,
|
||||
};
|
||||
let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
|
||||
unsafe {
|
||||
self.device
|
||||
.blas
|
||||
.gemm_strided_batched(cfg, lhs, rhs, &mut out)
|
||||
}?;
|
||||
CudaStorageSlice::F64(out)
|
||||
}
|
||||
_ => return Err(CudaError::InternalError("dtype mismatch in matmul op")),
|
||||
};
|
||||
let device = dev.clone();
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
}
|
||||
|
@ -71,4 +71,14 @@ impl CudaStorage {
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
/// Fallback `matmul_impl` for builds without CUDA support: always fails with
/// [`Error::NotCompiledWithCudaSupport`].
///
/// The signature mirrors the CUDA implementation (rhs operand, `(b, m, n, k)`
/// dimensions, lhs/rhs strides) so call sites compile identically under both
/// configurations; all parameters are intentionally ignored.
pub(crate) fn matmul_impl(
    &self,
    _: &Self,
    _: (usize, usize, usize, usize),
    _: &[usize],
    _: &[usize],
) -> Result<Self> {
    Err(Error::NotCompiledWithCudaSupport)
}
|
||||
}
|
||||
|
@ -132,12 +132,13 @@ impl Storage {
|
||||
self.same_device(rhs, "matmul")?;
|
||||
self.same_dtype(rhs, "matmul")?;
|
||||
match (self, rhs) {
|
||||
(Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => {
|
||||
let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
|
||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
||||
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(_), Self::Cuda(_)) => {
|
||||
todo!()
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
|
Reference in New Issue
Block a user