mirror of https://github.com/huggingface/candle.git

Start adding the cublas based matmul.
@@ -1,5 +1,6 @@
 use crate::{CpuStorage, DType, Shape};
 use candle_kernels as kernels;
+use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
 use std::sync::Arc;
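The new import pulls in cudarc's strided-batched GEMM types. For each batch index i, cublasGemmStridedBatched computes C_i = alpha * op(A_i) * op(B_i) + beta * C_i, where A_i starts stride_a elements after A_{i-1} (and likewise for B and C). As a point of reference only (this helper is not part of the commit), a naive row-major CPU equivalent of the batched product looks like this:

    // Illustrative only: a naive CPU reference for a contiguous row-major
    // (b, m, k) x (b, k, n) batched matmul, handy for validating the cuBLAS path.
    fn batched_matmul_ref(
        (b, m, n, k): (usize, usize, usize, usize),
        lhs: &[f32],
        rhs: &[f32],
    ) -> Vec<f32> {
        let mut out = vec![0f32; b * m * n];
        for bi in 0..b {
            let a = &lhs[bi * m * k..(bi + 1) * m * k];
            let x = &rhs[bi * k * n..(bi + 1) * k * n];
            let c = &mut out[bi * m * n..(bi + 1) * m * n];
            for i in 0..m {
                for j in 0..n {
                    // C[i, j] = sum over l of A[i, l] * B[l, j]
                    c[i * n + j] = (0..k).map(|l| a[i * k + l] * x[l * n + j]).sum();
                }
            }
        }
        out
    }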
@@ -292,4 +293,80 @@ impl CudaStorage {
             }
         }
     }
+
+    pub(crate) fn matmul_impl(
+        &self,
+        rhs: &Self,
+        (b, m, n, k): (usize, usize, usize, usize),
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<Self> {
+        use cudarc::cublas::sys::cublasOperation_t;
+        // A (b, m, k) x (b, k, n) matmul produces b * m * n output elements.
+        let elem_count = b * m * n;
+        let dev = &self.device;
+        let slice = match (&self.slice, &rhs.slice) {
+            (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
+                let gemm = GemmConfig {
+                    alpha: 1.,
+                    // beta is 0: the freshly allocated output is uninitialized.
+                    beta: 0.,
+                    m: m as i32,
+                    n: n as i32,
+                    k: k as i32,
+                    lda: n as i32, // TODO
+                    ldb: k as i32, // TODO
+                    ldc: n as i32, // TODO
+                    transa: cublasOperation_t::CUBLAS_OP_N,
+                    transb: cublasOperation_t::CUBLAS_OP_T,
+                };
+                let cfg = StridedBatchedConfig {
+                    batch_size: b as i32,
+                    gemm,
+                    stride_a: lhs_stride[0] as i64,
+                    stride_b: rhs_stride[0] as i64,
+                    stride_c: (m * n) as i64, // each output matrix is m * n elements
+                };
+                let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
+                unsafe {
+                    self.device
+                        .blas
+                        .gemm_strided_batched(cfg, lhs, rhs, &mut out)
+                }?;
+                CudaStorageSlice::F32(out)
+            }
+            (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
+                let gemm = GemmConfig {
+                    alpha: 1.,
+                    // As above, beta is 0 for the uninitialized output buffer.
+                    beta: 0.,
+                    m: m as i32,
+                    n: n as i32,
+                    k: k as i32,
+                    lda: n as i32, // TODO
+                    ldb: k as i32, // TODO
+                    ldc: n as i32, // TODO
+                    transa: cublasOperation_t::CUBLAS_OP_N,
+                    transb: cublasOperation_t::CUBLAS_OP_T,
+                };
+                let cfg = StridedBatchedConfig {
+                    batch_size: b as i32,
+                    gemm,
+                    stride_a: lhs_stride[0] as i64,
+                    stride_b: rhs_stride[0] as i64,
+                    stride_c: (m * n) as i64, // each output matrix is m * n elements
+                };
+                let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
+                unsafe {
+                    self.device
+                        .blas
+                        .gemm_strided_batched(cfg, lhs, rhs, &mut out)
+                }?;
+                CudaStorageSlice::F64(out)
+            }
+            _ => return Err(CudaError::InternalError("dtype mismatch in matmul op")),
+        };
+        let device = dev.clone();
+        Ok(Self { slice, device })
+    }
 }
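The lda/ldb/ldc TODOs stem from cuBLAS expecting column-major data while these tensors are row-major. One standard way to fill them in for contiguous inputs, sketched below under that assumption (and not necessarily the fix this commit eventually gets), is to compute C^T = B^T * A^T: a row-major matrix reinterpreted as column-major is its own transpose, so swapping the operands and the m/n dimensions avoids any explicit transposition.

    // Sketch only: GEMM parameters for contiguous row-major tensors, computing
    // the column-major C^T = B^T * A^T. With this config the operands must be
    // swapped at the call site: gemm_strided_batched(cfg, rhs, lhs, &mut out).
    use cudarc::cublas::sys::cublasOperation_t;
    use cudarc::cublas::{GemmConfig, StridedBatchedConfig};

    fn row_major_cfg(b: usize, m: usize, n: usize, k: usize) -> StridedBatchedConfig<f32> {
        let gemm = GemmConfig {
            alpha: 1.0,
            beta: 0.0,
            m: n as i32, // rows of the column-major output C^T (n x m)
            n: m as i32, // cols of C^T
            k: k as i32,
            lda: n as i32, // row-major B (k x n) viewed column-major is B^T (n x k)
            ldb: k as i32, // row-major A (m x k) viewed column-major is A^T (k x m)
            ldc: n as i32, // row-major C (m x n) viewed column-major is C^T (n x m)
            transa: cublasOperation_t::CUBLAS_OP_N,
            transb: cublasOperation_t::CUBLAS_OP_N,
        };
        StridedBatchedConfig {
            batch_size: b as i32,
            gemm,
            stride_a: (k * n) as i64, // elements per batch of B
            stride_b: (m * k) as i64, // elements per batch of A
            stride_c: (m * n) as i64, // elements per batch of C
        }
    }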
@@ -71,4 +71,14 @@ impl CudaStorage {
     ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+
+    pub(crate) fn matmul_impl(
+        &self,
+        _: &Self,
+        _: (usize, usize, usize, usize),
+        _: &[usize],
+        _: &[usize],
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
 }
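This second copy of matmul_impl keeps the crate compiling when CUDA support is disabled: it mirrors the real signature and always returns Error::NotCompiledWithCudaSupport. Presumably the two backends are selected with feature gates along these lines (module names assumed, they are not visible in this diff):

    #[cfg(feature = "cuda")]
    mod cuda_backend;
    #[cfg(not(feature = "cuda"))]
    mod dummy_cuda_backend;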
@@ -132,12 +132,13 @@ impl Storage {
         self.same_device(rhs, "matmul")?;
         self.same_dtype(rhs, "matmul")?;
         match (self, rhs) {
-            (Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => {
-                let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
+            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+                let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
                 Ok(Self::Cpu(storage))
             }
-            (Self::Cuda(_), Self::Cuda(_)) => {
-                todo!()
+            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
+                let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
+                Ok(Self::Cuda(storage))
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
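With this dispatch in place, a CUDA matmul flows from the tensor level down to the cuBLAS call above. A hypothetical end-to-end use once the TODOs are resolved (Tensor, Device, and DType are assumed from the surrounding crate, not shown in this diff):

    // bmnk = (2, 4, 3, 8): lhs is (2, 4, 8), rhs is (2, 8, 3).
    let dev = Device::new_cuda(0)?;
    let lhs = Tensor::zeros((2, 4, 8), DType::F32, &dev)?;
    let rhs = Tensor::zeros((2, 8, 3), DType::F32, &dev)?;
    let out = lhs.matmul(&rhs)?; // (2, 4, 3), lands in the Storage::Cuda arm above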