From 683730c21df8f44d12cca3aead1766120cfff149 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 22 Jun 2023 18:03:53 +0100 Subject: [PATCH] Add the cublas handle to the cuda device. --- src/cuda_backend.rs | 183 +++++++++++++++++++++++--------------- src/dummy_cuda_backend.rs | 2 +- src/storage.rs | 2 +- 3 files changed, 114 insertions(+), 73 deletions(-) diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index c34cd3ea..36a68731 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -1,6 +1,7 @@ use crate::{CpuStorage, DType, Shape}; use candle_kernels as kernels; use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig}; +use std::sync::Arc; /// cudarc related errors #[derive(thiserror::Error, Debug)] @@ -11,6 +12,9 @@ pub enum CudaError { #[error(transparent)] Compiler(#[from] cudarc::nvrtc::CompileError), + #[error(transparent)] + Cublas(#[from] cudarc::cublas::result::CublasError), + #[error("{op} only supports contiguous tensors")] RequiresContiguous { op: &'static str }, @@ -24,54 +28,77 @@ pub enum CudaError { type Result = std::result::Result; #[derive(Debug, Clone)] -pub struct CudaDevice(std::sync::Arc); +pub struct CudaDevice { + device: Arc, + #[allow(dead_code)] + blas: Arc, +} + +impl std::ops::Deref for CudaDevice { + type Target = Arc; + + fn deref(&self) -> &Self::Target { + &self.device + } +} impl CudaDevice { pub(crate) fn new(ordinal: usize) -> Result { let device = cudarc::driver::CudaDevice::new(ordinal)?; - Ok(Self(device)) + let blas = cudarc::cublas::CudaBlas::new(device.clone())?; + Ok(Self { + device, + blas: Arc::new(blas), + }) } pub(crate) fn ordinal(&self) -> usize { - self.0.ordinal() + self.device.ordinal() } pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { let elem_count = shape.elem_count(); - match dtype { + let slice = match dtype { DType::F32 => { - let data = self.0.alloc_zeros::(elem_count)?; - Ok(CudaStorage::F32(data)) + let data = self.alloc_zeros::(elem_count)?; + CudaStorageSlice::F32(data) } DType::F64 => { - let data = self.0.alloc_zeros::(elem_count)?; - Ok(CudaStorage::F64(data)) + let data = self.alloc_zeros::(elem_count)?; + CudaStorageSlice::F64(data) } - } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) } pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result { let elem_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(elem_count as u32); - let dev = &self.0; - match dtype { + let slice = match dtype { DType::F32 => { // SAFETY: Set later by running the fill kernel. - let data = unsafe { dev.alloc::(elem_count) }?; + let data = unsafe { self.alloc::(elem_count) }?; let func = self.get_or_load_func("fill_f32", kernels::FILL)?; let params = (&data, v as f32, elem_count); unsafe { func.launch(cfg, params) }?; - Ok(CudaStorage::F32(data)) + CudaStorageSlice::F32(data) } DType::F64 => { // SAFETY: Set later by running the fill kernel. - let data = unsafe { dev.alloc::(elem_count) }?; + let data = unsafe { self.alloc::(elem_count) }?; let func = self.get_or_load_func("fill_f64", kernels::FILL)?; let params = (&data, v, elem_count); unsafe { func.launch(cfg, params) }?; - Ok(CudaStorage::F64(data)) + CudaStorageSlice::F64(data) } - } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) } pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { @@ -79,16 +106,20 @@ impl CudaDevice { } pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result { - match storage { + let slice = match storage { CpuStorage::F32(storage) => { - let data = self.0.htod_sync_copy(storage)?; - Ok(CudaStorage::F32(data)) + let data = self.htod_sync_copy(storage)?; + CudaStorageSlice::F32(data) } CpuStorage::F64(storage) => { - let data = self.0.htod_sync_copy(storage)?; - Ok(CudaStorage::F64(data)) + let data = self.htod_sync_copy(storage)?; + CudaStorageSlice::F64(data) } - } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) } fn get_or_load_func( @@ -96,11 +127,10 @@ impl CudaDevice { module_name: &'static str, ptx: &'static str, ) -> Result { - let dev = &self.0; - if !dev.has_func(module_name, module_name) { - dev.load_ptx(ptx.into(), module_name, &[module_name])?; + if !self.has_func(module_name, module_name) { + self.load_ptx(ptx.into(), module_name, &[module_name])?; } - dev.get_func(module_name, module_name) + self.get_func(module_name, module_name) // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is // able to only build the error value if needed. .ok_or(CudaError::MissingKernel { module_name }) @@ -108,31 +138,36 @@ impl CudaDevice { } #[derive(Debug)] -pub enum CudaStorage { +enum CudaStorageSlice { F32(CudaSlice), F64(CudaSlice), } +#[derive(Debug)] +pub struct CudaStorage { + slice: CudaStorageSlice, + device: CudaDevice, +} + impl CudaStorage { pub fn try_clone(&self) -> Result { - match self { - Self::F32(slice) => Ok(Self::F32(slice.try_clone()?)), - Self::F64(slice) => Ok(Self::F64(slice.try_clone()?)), - } + let slice = match &self.slice { + CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?), + CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?), + }; + let device = self.device.clone(); + Ok(Self { slice, device }) } pub fn dtype(&self) -> DType { - match self { - Self::F32(_) => DType::F32, - Self::F64(_) => DType::F64, + match self.slice { + CudaStorageSlice::F32(_) => DType::F32, + CudaStorageSlice::F64(_) => DType::F64, } } - pub fn device(&self) -> CudaDevice { - match self { - Self::F32(slice) => CudaDevice(slice.device()), - Self::F64(slice) => CudaDevice(slice.device()), - } + pub fn device(&self) -> &CudaDevice { + &self.device } pub(crate) fn affine_impl( @@ -146,27 +181,29 @@ impl CudaStorage { let el_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = self.device(); - let ds = dev.0.htod_copy([dims, stride].concat())?; - match self { - Self::F32(arg) => { + let ds = dev.htod_copy([dims, stride].concat())?; + let slice = match &self.slice { + CudaStorageSlice::F32(arg) => { let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. - let out = unsafe { dev.0.alloc::(el_count) }?; + let out = unsafe { dev.alloc::(el_count) }?; let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32); // SAFETY: ffi. unsafe { func.launch(cfg, params) }?; - Ok(Self::F32(out)) + CudaStorageSlice::F32(out) } - Self::F64(arg) => { + CudaStorageSlice::F64(arg) => { let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. - let out = unsafe { dev.0.alloc::(el_count) }?; + let out = unsafe { dev.alloc::(el_count) }?; let params = (el_count, dims.len(), &ds, arg, &out, mul, add); // SAFETY: ffi. unsafe { func.launch(cfg, params) }?; - Ok(Self::F64(out)) + CudaStorageSlice::F64(out) } - } + }; + let device = dev.clone(); + Ok(Self { slice, device }) } pub(crate) fn unary_impl( @@ -177,28 +214,30 @@ impl CudaStorage { let dims = shape.dims(); let el_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); - let dev = self.device(); - let ds = dev.0.htod_copy([dims, stride].concat())?; - match self { - Self::F32(arg) => { + let dev = &self.device; + let ds = dev.htod_copy([dims, stride].concat())?; + let slice = match &self.slice { + CudaStorageSlice::F32(arg) => { let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?; // SAFETY: Set later by running the kernel. - let out = unsafe { dev.0.alloc::(el_count) }?; + let out = unsafe { dev.alloc::(el_count) }?; let params = (el_count, dims.len(), &ds, arg, &out); // SAFETY: ffi. unsafe { func.launch(cfg, params) }?; - Ok(Self::F32(out)) + CudaStorageSlice::F32(out) } - Self::F64(arg) => { + CudaStorageSlice::F64(arg) => { let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?; // SAFETY: Set later by running the kernel. - let out = unsafe { dev.0.alloc::(el_count) }?; + let out = unsafe { dev.alloc::(el_count) }?; let params = (el_count, dims.len(), &ds, arg, &out); // SAFETY: ffi. unsafe { func.launch(cfg, params) }?; - Ok(Self::F64(out)) + CudaStorageSlice::F64(out) } - } + }; + let device = dev.clone(); + Ok(Self { slice, device }) } pub(crate) fn binary_impl( @@ -212,39 +251,41 @@ impl CudaStorage { let dims = shape.dims(); let cfg = LaunchConfig::for_num_elems(elem_count as u32); let dev = self.device(); - let dims_and_strides = dev.0.htod_copy([dims, lhs_stride, rhs_stride].concat())?; - match (self, rhs) { - (Self::F32(lhs), Self::F32(rhs)) => { + let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?; + let slice = match (&self.slice, &rhs.slice) { + (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?; // SAFETY: Set later by running the kernel. - let out = unsafe { dev.0.alloc::(elem_count) }?; + let out = unsafe { dev.alloc::(elem_count) }?; let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); // SAFETY: ffi unsafe { func.launch(cfg, params) }?; - Ok(Self::F32(out)) + CudaStorageSlice::F32(out) } - (Self::F64(lhs), Self::F64(rhs)) => { + (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { // SAFETY: Set later by running the kernel. let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?; - let out = unsafe { dev.0.alloc::(elem_count) }?; + let out = unsafe { dev.alloc::(elem_count) }?; let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); // SAFETY: ffi unsafe { func.launch(cfg, params) }?; - Ok(Self::F64(out)) + CudaStorageSlice::F64(out) } // The dtypes should have been checked at this point so this is an internal error. - _ => Err(CudaError::InternalError("dtype mismatch in binary op")), - } + _ => return Err(CudaError::InternalError("dtype mismatch in binary op")), + }; + let device = dev.clone(); + Ok(Self { slice, device }) } pub(crate) fn to_cpu_storage(&self) -> Result { - match self { - Self::F32(slice) => { + match &self.slice { + CudaStorageSlice::F32(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice)?; Ok(CpuStorage::F32(cpu_storage)) } - Self::F64(slice) => { + CudaStorageSlice::F64(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice)?; Ok(CpuStorage::F64(cpu_storage)) diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs index 939247cd..4bc59f61 100644 --- a/src/dummy_cuda_backend.rs +++ b/src/dummy_cuda_backend.rs @@ -46,7 +46,7 @@ impl CudaStorage { fail!() } - pub fn device(&self) -> CudaDevice { + pub fn device(&self) -> &CudaDevice { fail!() } diff --git a/src/storage.rs b/src/storage.rs index b4c2c272..4bc24149 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -22,7 +22,7 @@ impl Storage { pub fn device(&self) -> Device { match self { Self::Cpu(_) => Device::Cpu, - Self::Cuda(storage) => Device::Cuda(storage.device()), + Self::Cuda(storage) => Device::Cuda(storage.device().clone()), } }