Add the cublas handle to the cuda device.

Author: laurent
Date: 2023-06-22 18:03:53 +01:00
Parent: 7d9a8ff3f9
Commit: 683730c21d
3 changed files with 114 additions and 73 deletions
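The commit replaces the old newtype `CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>)` with a struct that bundles the driver device and a `cudarc::cublas::CudaBlas` handle (still marked `#[allow(dead_code)]` here), and adds a `Deref` impl so the previous `self.0.alloc_zeros(...)`-style calls become plain `self.alloc_zeros(...)`. A minimal, self-contained sketch of that wrapper-plus-Deref pattern, using hypothetical stand-in types and no cudarc at all:

use std::ops::Deref;
use std::sync::Arc;

// Stand-ins for cudarc's driver device and cuBLAS handle.
struct DriverDevice;
impl DriverDevice {
    fn ordinal(&self) -> usize {
        0
    }
}
struct BlasHandle;

// The wrapper owns both handles; cloning it only bumps two refcounts.
#[derive(Clone)]
struct Device {
    device: Arc<DriverDevice>,
    #[allow(dead_code)]
    blas: Arc<BlasHandle>,
}

// Deref forwards calls to the inner driver device, so `dev.ordinal()`
// keeps working without the old `dev.0` spelling.
impl Deref for Device {
    type Target = Arc<DriverDevice>;
    fn deref(&self) -> &Self::Target {
        &self.device
    }
}

fn main() {
    let dev = Device {
        device: Arc::new(DriverDevice),
        blas: Arc::new(BlasHandle),
    };
    // Resolved through Deref to the inner DriverDevice.
    assert_eq!(dev.ordinal(), 0);
    let dev2 = dev.clone();
    assert_eq!(dev2.ordinal(), 0);
}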

Changed file 1 of 3

@@ -1,6 +1,7 @@
 use crate::{CpuStorage, DType, Shape};
 use candle_kernels as kernels;
 use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
+use std::sync::Arc;

 /// cudarc related errors
 #[derive(thiserror::Error, Debug)]
@@ -11,6 +12,9 @@ pub enum CudaError {
     #[error(transparent)]
     Compiler(#[from] cudarc::nvrtc::CompileError),

+    #[error(transparent)]
+    Cublas(#[from] cudarc::cublas::result::CublasError),
+
     #[error("{op} only supports contiguous tensors")]
     RequiresContiguous { op: &'static str },
@@ -24,54 +28,77 @@ pub enum CudaError {
 type Result<T> = std::result::Result<T, CudaError>;

 #[derive(Debug, Clone)]
-pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);
+pub struct CudaDevice {
+    device: Arc<cudarc::driver::CudaDevice>,
+    #[allow(dead_code)]
+    blas: Arc<cudarc::cublas::CudaBlas>,
+}
+
+impl std::ops::Deref for CudaDevice {
+    type Target = Arc<cudarc::driver::CudaDevice>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.device
+    }
+}

 impl CudaDevice {
     pub(crate) fn new(ordinal: usize) -> Result<Self> {
         let device = cudarc::driver::CudaDevice::new(ordinal)?;
-        Ok(Self(device))
+        let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
+        Ok(Self {
+            device,
+            blas: Arc::new(blas),
+        })
     }

     pub(crate) fn ordinal(&self) -> usize {
-        self.0.ordinal()
+        self.device.ordinal()
     }

     pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
         let elem_count = shape.elem_count();
-        match dtype {
+        let slice = match dtype {
             DType::F32 => {
-                let data = self.0.alloc_zeros::<f32>(elem_count)?;
-                Ok(CudaStorage::F32(data))
+                let data = self.alloc_zeros::<f32>(elem_count)?;
+                CudaStorageSlice::F32(data)
             }
             DType::F64 => {
-                let data = self.0.alloc_zeros::<f64>(elem_count)?;
-                Ok(CudaStorage::F64(data))
+                let data = self.alloc_zeros::<f64>(elem_count)?;
+                CudaStorageSlice::F64(data)
             }
-        }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
     }

     pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let dev = &self.0;
-        match dtype {
+        let slice = match dtype {
             DType::F32 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { dev.alloc::<f32>(elem_count) }?;
+                let data = unsafe { self.alloc::<f32>(elem_count) }?;
                 let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
                 let params = (&data, v as f32, elem_count);
                 unsafe { func.launch(cfg, params) }?;
-                Ok(CudaStorage::F32(data))
+                CudaStorageSlice::F32(data)
             }
             DType::F64 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { dev.alloc::<f64>(elem_count) }?;
+                let data = unsafe { self.alloc::<f64>(elem_count) }?;
                 let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
                 let params = (&data, v, elem_count);
                 unsafe { func.launch(cfg, params) }?;
-                Ok(CudaStorage::F64(data))
+                CudaStorageSlice::F64(data)
             }
-        }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
     }

     pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
@@ -79,16 +106,20 @@ impl CudaDevice {
     }

     pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
-        match storage {
+        let slice = match storage {
             CpuStorage::F32(storage) => {
-                let data = self.0.htod_sync_copy(storage)?;
-                Ok(CudaStorage::F32(data))
+                let data = self.htod_sync_copy(storage)?;
+                CudaStorageSlice::F32(data)
             }
             CpuStorage::F64(storage) => {
-                let data = self.0.htod_sync_copy(storage)?;
-                Ok(CudaStorage::F64(data))
+                let data = self.htod_sync_copy(storage)?;
+                CudaStorageSlice::F64(data)
             }
-        }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
     }

     fn get_or_load_func(
@@ -96,11 +127,10 @@ impl CudaDevice {
         module_name: &'static str,
         ptx: &'static str,
     ) -> Result<CudaFunction> {
-        let dev = &self.0;
-        if !dev.has_func(module_name, module_name) {
-            dev.load_ptx(ptx.into(), module_name, &[module_name])?;
+        if !self.has_func(module_name, module_name) {
+            self.load_ptx(ptx.into(), module_name, &[module_name])?;
         }
-        dev.get_func(module_name, module_name)
+        self.get_func(module_name, module_name)
             // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
             // able to only build the error value if needed.
             .ok_or(CudaError::MissingKernel { module_name })
@@ -108,31 +138,36 @@ impl CudaDevice {
 }

 #[derive(Debug)]
-pub enum CudaStorage {
+enum CudaStorageSlice {
     F32(CudaSlice<f32>),
     F64(CudaSlice<f64>),
 }

+#[derive(Debug)]
+pub struct CudaStorage {
+    slice: CudaStorageSlice,
+    device: CudaDevice,
+}
+
 impl CudaStorage {
     pub fn try_clone(&self) -> Result<Self> {
-        match self {
-            Self::F32(slice) => Ok(Self::F32(slice.try_clone()?)),
-            Self::F64(slice) => Ok(Self::F64(slice.try_clone()?)),
-        }
+        let slice = match &self.slice {
+            CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
+            CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
+        };
+        let device = self.device.clone();
+        Ok(Self { slice, device })
     }

     pub fn dtype(&self) -> DType {
-        match self {
-            Self::F32(_) => DType::F32,
-            Self::F64(_) => DType::F64,
+        match self.slice {
+            CudaStorageSlice::F32(_) => DType::F32,
+            CudaStorageSlice::F64(_) => DType::F64,
         }
     }

-    pub fn device(&self) -> CudaDevice {
-        match self {
-            Self::F32(slice) => CudaDevice(slice.device()),
-            Self::F64(slice) => CudaDevice(slice.device()),
-        }
+    pub fn device(&self) -> &CudaDevice {
+        &self.device
     }

     pub(crate) fn affine_impl(
@@ -146,27 +181,29 @@ impl CudaStorage {
         let el_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let dev = self.device();
-        let ds = dev.0.htod_copy([dims, stride].concat())?;
-        match self {
-            Self::F32(arg) => {
+        let ds = dev.htod_copy([dims, stride].concat())?;
+        let slice = match &self.slice {
+            CudaStorageSlice::F32(arg) => {
                 let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
                 // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.0.alloc::<f32>(el_count) }?;
+                let out = unsafe { dev.alloc::<f32>(el_count) }?;
                 let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
                 // SAFETY: ffi.
                 unsafe { func.launch(cfg, params) }?;
-                Ok(Self::F32(out))
+                CudaStorageSlice::F32(out)
             }
-            Self::F64(arg) => {
+            CudaStorageSlice::F64(arg) => {
                 let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
                 // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.0.alloc::<f64>(el_count) }?;
+                let out = unsafe { dev.alloc::<f64>(el_count) }?;
                 let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
                 // SAFETY: ffi.
                 unsafe { func.launch(cfg, params) }?;
-                Ok(Self::F64(out))
+                CudaStorageSlice::F64(out)
             }
-        }
+        };
+        let device = dev.clone();
+        Ok(Self { slice, device })
     }

     pub(crate) fn unary_impl<U: crate::op::UnaryOp>(
@@ -177,28 +214,30 @@ impl CudaStorage {
         let dims = shape.dims();
         let el_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
-        let dev = self.device();
-        let ds = dev.0.htod_copy([dims, stride].concat())?;
-        match self {
-            Self::F32(arg) => {
+        let dev = &self.device;
+        let ds = dev.htod_copy([dims, stride].concat())?;
+        let slice = match &self.slice {
+            CudaStorageSlice::F32(arg) => {
                 let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
                 // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.0.alloc::<f32>(el_count) }?;
+                let out = unsafe { dev.alloc::<f32>(el_count) }?;
                 let params = (el_count, dims.len(), &ds, arg, &out);
                 // SAFETY: ffi.
                 unsafe { func.launch(cfg, params) }?;
-                Ok(Self::F32(out))
+                CudaStorageSlice::F32(out)
             }
-            Self::F64(arg) => {
+            CudaStorageSlice::F64(arg) => {
                 let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
                 // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.0.alloc::<f64>(el_count) }?;
+                let out = unsafe { dev.alloc::<f64>(el_count) }?;
                 let params = (el_count, dims.len(), &ds, arg, &out);
                 // SAFETY: ffi.
                 unsafe { func.launch(cfg, params) }?;
-                Ok(Self::F64(out))
+                CudaStorageSlice::F64(out)
             }
-        }
+        };
+        let device = dev.clone();
+        Ok(Self { slice, device })
     }

     pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
@@ -212,39 +251,41 @@ impl CudaStorage {
         let dims = shape.dims();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
         let dev = self.device();
-        let dims_and_strides = dev.0.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
-        match (self, rhs) {
-            (Self::F32(lhs), Self::F32(rhs)) => {
+        let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
+        let slice = match (&self.slice, &rhs.slice) {
+            (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
                 let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
                 // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
+                let out = unsafe { dev.alloc::<f32>(elem_count) }?;
                 let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
                 // SAFETY: ffi
                 unsafe { func.launch(cfg, params) }?;
-                Ok(Self::F32(out))
+                CudaStorageSlice::F32(out)
             }
-            (Self::F64(lhs), Self::F64(rhs)) => {
+            (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
                 // SAFETY: Set later by running the kernel.
                 let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
-                let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
+                let out = unsafe { dev.alloc::<f64>(elem_count) }?;
                 let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
                 // SAFETY: ffi
                 unsafe { func.launch(cfg, params) }?;
-                Ok(Self::F64(out))
+                CudaStorageSlice::F64(out)
             }
             // The dtypes should have been checked at this point so this is an internal error.
-            _ => Err(CudaError::InternalError("dtype mismatch in binary op")),
-        }
+            _ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
+        };
+        let device = dev.clone();
+        Ok(Self { slice, device })
     }

     pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
-        match self {
-            Self::F32(slice) => {
+        match &self.slice {
+            CudaStorageSlice::F32(slice) => {
                 let dev = slice.device();
                 let cpu_storage = dev.dtoh_sync_copy(slice)?;
                 Ok(CpuStorage::F32(cpu_storage))
             }
-            Self::F64(slice) => {
+            CudaStorageSlice::F64(slice) => {
                 let dev = slice.device();
                 let cpu_storage = dev.dtoh_sync_copy(slice)?;
                 Ok(CpuStorage::F64(cpu_storage))
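
The other half of the change is that the raw `CudaSlice`s move into a private `CudaStorageSlice` enum while the public `CudaStorage` pairs that slice with the `CudaDevice` it lives on, so `device()` can return a reference instead of rebuilding a device from the slice. A simplified, std-only mirror of that shape (hypothetical names, no CUDA calls) to show how the pieces fit together:

use std::sync::Arc;

// Stand-in for CudaDevice: cheap to clone, identity is the inner Arc.
#[derive(Clone)]
struct Device(Arc<()>);

// Mirrors the private CudaStorageSlice enum: one variant per dtype.
#[derive(Clone)]
enum StorageSlice {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

// Mirrors the new public CudaStorage: the raw data plus the device it lives on.
struct Storage {
    slice: StorageSlice,
    device: Device,
}

impl Storage {
    // device() hands out a borrow; callers clone only when they need ownership.
    fn device(&self) -> &Device {
        &self.device
    }

    // try_clone() copies the slice and the (cheap) device handle together.
    fn try_clone(&self) -> Result<Self, String> {
        let slice = self.slice.clone();
        let device = self.device.clone();
        Ok(Self { slice, device })
    }

    fn dtype(&self) -> &'static str {
        match self.slice {
            StorageSlice::F32(_) => "f32",
            StorageSlice::F64(_) => "f64",
        }
    }
}

fn main() {
    let dev = Device(Arc::new(()));
    let s = Storage {
        slice: StorageSlice::F32(vec![0.0; 4]),
        device: dev,
    };
    let s2 = s.try_clone().unwrap();
    assert_eq!(s.dtype(), s2.dtype());
    // Both copies point at the same underlying device handle.
    assert!(Arc::ptr_eq(&s.device().0, &s2.device().0));

    // Cloning the borrowed device only where an owned copy is needed,
    // as the storage.device().clone() call site below does.
    let d = Storage {
        slice: StorageSlice::F64(vec![1.0; 2]),
        device: s.device().clone(),
    };
    assert_eq!(d.dtype(), "f64");
}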

Changed file 2 of 3

@@ -46,7 +46,7 @@ impl CudaStorage {
         fail!()
     }

-    pub fn device(&self) -> CudaDevice {
+    pub fn device(&self) -> &CudaDevice {
         fail!()
     }

Changed file 3 of 3

@@ -22,7 +22,7 @@ impl Storage {
     pub fn device(&self) -> Device {
         match self {
             Self::Cpu(_) => Device::Cpu,
-            Self::Cuda(storage) => Device::Cuda(storage.device()),
+            Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
         }
     }
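
Because `device()` now returns `&CudaDevice`, this call site clones to produce an owned value for `Device::Cuda(...)`. That clone is cheap: with the new layout it is just two `Arc` refcount bumps (driver device and cuBLAS handle), not a new CUDA context. A tiny sketch, with `String`s standing in for the real handles:

use std::sync::Arc;

// Two Arc fields mirror the new CudaDevice { device, blas } layout.
#[derive(Clone)]
struct Device {
    driver: Arc<String>,
    blas: Arc<String>,
}

fn main() {
    let dev = Device {
        driver: Arc::new("driver handle".to_string()),
        blas: Arc::new("cublas handle".to_string()),
    };
    // `storage.device()` returns a borrow, so the call site clones to get
    // an owned value; that clone is only two refcount increments.
    let owned = dev.clone();
    assert_eq!(Arc::strong_count(&dev.driver), 2);
    assert_eq!(Arc::strong_count(&dev.blas), 2);
    assert!(Arc::ptr_eq(&dev.driver, &owned.driver));
}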