From 64264d97c17e2acff6fc6f6fcab7038b8d3f3e40 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 11 Jul 2023 11:17:02 +0100
Subject: [PATCH] Modular backends (#138)

* Add some trait to formalize backends.

* Use the generic backend trait.
---
 candle-core/src/backend.rs            |  71 +++
 candle-core/src/conv.rs               |   2 +-
 candle-core/src/cuda_backend.rs       | 599 +++++++++++++-------------
 candle-core/src/device.rs             |  13 +-
 candle-core/src/dummy_cuda_backend.rs | 130 +++---
 candle-core/src/error.rs              |   2 +-
 candle-core/src/lib.rs                |   5 +-
 candle-core/src/storage.rs            |   1 +
 candle-core/src/tensor.rs             |   7 +-
 9 files changed, 457 insertions(+), 373 deletions(-)
 create mode 100644 candle-core/src/backend.rs

diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
new file mode 100644
index 00000000..aa35703d
--- /dev/null
+++ b/candle-core/src/backend.rs
@@ -0,0 +1,71 @@
+use crate::{CpuStorage, DType, Layout, Result, Shape};
+
+pub(crate) trait BackendStorage: Sized {
+    type Device: BackendDevice;
+
+    fn try_clone(&self, _: &Layout) -> Result<Self>;
+
+    fn dtype(&self) -> DType;
+
+    fn device(&self) -> &Self::Device;
+
+    fn to_cpu_storage(&self) -> Result<CpuStorage>;
+
+    fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
+
+    fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
+
+    fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self>;
+
+    fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
+
+    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
+
+    fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self>;
+
+    fn binary_impl<B: crate::op::BinaryOp>(&self, _: &Self, _: &Layout, _: &Layout)
+        -> Result<Self>;
+
+    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
+
+    fn conv1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _params: &crate::conv::ParamsConv1D,
+    ) -> Result<Self>;
+
+    fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
+
+    fn matmul(
+        &self,
+        _: &Self,
+        _: (usize, usize, usize, usize),
+        _: &Layout,
+        _: &Layout,
+    ) -> Result<Self>;
+
+    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
+}
+
+pub(crate) trait BackendDevice: Sized + std::fmt::Debug + Clone {
+    type Storage: BackendStorage;
+
+    // TODO: Make the usize generic and part of a generic DeviceLocation.
+    fn new(_: usize) -> Result<Self>;
+
+    fn location(&self) -> crate::DeviceLocation;
+
+    fn same_device(&self, _: &Self) -> bool;
+
+    fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
+
+    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
+
+    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
+
+    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
+
+    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
+}
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 041bb6fb..4cf9d0ad 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -1,5 +1,5 @@
 #[derive(Debug, Clone, PartialEq, Eq)]
-pub(crate) struct ParamsConv1D {
+pub struct ParamsConv1D {
     pub(crate) b_size: Option<usize>,
     // Maybe we should have a version without l_in as this bit depends on the input and not only on
     // the weights.
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index c2d01a07..73bc9e34 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,4 +1,5 @@
-use crate::{CpuStorage, DType, Layout, Shape, WithDType};
+use crate::backend::{BackendDevice, BackendStorage};
+use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
 use candle_kernels as kernels;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{
@@ -22,9 +23,6 @@ pub enum CudaError {
     #[error(transparent)]
     Curand(#[from] cudarc::curand::result::CurandError),
 
-    #[error("{op} only supports contiguous tensors")]
-    RequiresContiguous { op: &'static str },
-
     #[error("missing kernel '{module_name}'")]
     MissingKernel { module_name: String },
 
@@ -58,7 +56,11 @@ pub enum CudaError {
     },
 }
 
-type Result<T> = std::result::Result<T, CudaError>;
+impl From<CudaError> for crate::Error {
+    fn from(val: CudaError) -> Self {
+        crate::Error::Cuda(Box::new(val))
+    }
+}
 
 /// Unique identifier for cuda devices.
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -98,220 +100,67 @@ impl std::ops::Deref for CudaDevice {
     }
 }
 
+trait WrapErr<O> {
+    fn w(self) -> std::result::Result<O, crate::Error>;
+}
+
+impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
+    fn w(self) -> std::result::Result<O, crate::Error> {
+        self.map_err(|e| crate::Error::Cuda(Box::new(e.into())))
+    }
+}
+
 impl CudaDevice {
-    pub(crate) fn new(ordinal: usize) -> Result<Self> {
-        let device = cudarc::driver::CudaDevice::new(ordinal)?;
-        let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
-        let curand = cudarc::curand::CudaRng::new(299792458, device.clone())?;
-        Ok(Self {
-            id: DeviceId::new(),
-            device,
-            blas: Arc::new(blas),
-            curand: Arc::new(Mutex::new(CudaRng(curand))),
-        })
-    }
-
-    pub(crate) fn same_id(&self, rhs: &Self) -> bool {
-        self.id == rhs.id
-    }
-
-    pub(crate) fn ordinal(&self) -> usize {
-        self.device.ordinal()
-    }
-
-    pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
-        let elem_count = shape.elem_count();
-        let slice = match dtype {
-            DType::U8 => {
-                let data = self.alloc_zeros::<u8>(elem_count)?;
-                CudaStorageSlice::U8(data)
-            }
-            DType::U32 => {
-                let data = self.alloc_zeros::<u32>(elem_count)?;
-                CudaStorageSlice::U32(data)
-            }
-            DType::BF16 => {
-                let data = self.alloc_zeros::<bf16>(elem_count)?;
-                CudaStorageSlice::BF16(data)
-            }
-            DType::F16 => {
-                let data = self.alloc_zeros::<f16>(elem_count)?;
-                CudaStorageSlice::F16(data)
-            }
-            DType::F32 => {
-                let data = self.alloc_zeros::<f32>(elem_count)?;
-                CudaStorageSlice::F32(data)
-            }
-            DType::F64 => {
-                let data = self.alloc_zeros::<f64>(elem_count)?;
-                CudaStorageSlice::F64(data)
-            }
-        };
-        Ok(CudaStorage {
-            slice,
-            device: self.clone(),
-        })
-    }
-
-    pub(crate) fn rand_uniform(
-        &self,
-        shape: &Shape,
-        dtype: DType,
-        lo: f64,
-        up: f64,
-    ) -> Result<CudaStorage> {
-        let elem_count = shape.elem_count();
-        let curand = self.curand.lock().unwrap();
-        let slice = match dtype {
-            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
-                Err(CudaError::UnsupportedDtype {
-                    dtype,
-                    op: "rand_uniform",
-                })?
-            }
-            DType::F32 => {
-                let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
-                curand.0.fill_with_uniform(&mut data)?;
-                CudaStorageSlice::F32(data)
-            }
-            DType::F64 => {
-                let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
-                curand.0.fill_with_uniform(&mut data)?;
-                CudaStorageSlice::F64(data)
-            }
-        };
-        if lo != 0.0 || up != 1.0 {
-            let layout = Layout::contiguous(shape);
-            Affine(up - lo, lo).map(&slice, self, &layout)?;
-        }
-        Ok(CudaStorage {
-            slice,
-            device: self.clone(),
-        })
-    }
-
-    pub(crate) fn rand_normal(
-        &self,
-        shape: &Shape,
-        dtype: DType,
-        mean: f64,
-        std: f64,
-    ) -> Result<CudaStorage> {
-        let elem_count = shape.elem_count();
-        let curand = self.curand.lock().unwrap();
-        let slice = match dtype {
-            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
-                Err(CudaError::UnsupportedDtype {
-                    dtype,
-                    op: "rand_normal",
-                })?
-            }
-            DType::F32 => {
-                let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
-                curand
-                    .0
-                    .fill_with_normal(&mut data, mean as f32, std as f32)?;
-                CudaStorageSlice::F32(data)
-            }
-            DType::F64 => {
-                let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
-                curand.0.fill_with_normal(&mut data, mean, std)?;
-                CudaStorageSlice::F64(data)
-            }
-        };
-        Ok(CudaStorage {
-            slice,
-            device: self.clone(),
-        })
-    }
-
-    pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+    fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
         let slice = match dtype {
             DType::U8 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<u8>(elem_count) }?;
+                let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
                 let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
                 let params = (&data, v as u8, elem_count);
-                unsafe { func.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::U8(data)
             }
             DType::U32 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<u32>(elem_count) }?;
+                let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
                 let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
                 let params = (&data, v as u32, elem_count);
-                unsafe { func.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::U32(data)
             }
             DType::BF16 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<bf16>(elem_count) }?;
+                let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
                 let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
                 let params = (&data, bf16::from_f64(v), elem_count);
-                unsafe { func.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::BF16(data)
             }
             DType::F16 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f16>(elem_count) }?;
+                let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
                 let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
                 let params = (&data, f16::from_f64(v), elem_count);
-                unsafe { func.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::F16(data)
             }
             DType::F32 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f32>(elem_count) }?;
+                let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                 let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
                 let params = (&data, v as f32, elem_count);
-                unsafe { func.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::F32(data)
             }
             DType::F64 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f64>(elem_count) }?;
+                let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
                 let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
                 let params = (&data, v, elem_count);
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F64(data)
-            }
-        };
-        Ok(CudaStorage {
-            slice,
-            device: self.clone(),
-        })
-    }
-
-    pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
-        self.const_impl(1., shape, dtype)
-    }
-
-    pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
-        let slice = match storage {
-            CpuStorage::U8(storage) => {
-                let data = self.htod_sync_copy(storage)?;
-                CudaStorageSlice::U8(data)
-            }
-            CpuStorage::U32(storage) => {
-                let data = self.htod_sync_copy(storage)?;
-                CudaStorageSlice::U32(data)
-            }
-            CpuStorage::BF16(storage) => {
-                let data = self.htod_sync_copy(storage)?;
-                CudaStorageSlice::BF16(data)
-            }
-            CpuStorage::F16(storage) => {
-                let data = self.htod_sync_copy(storage)?;
-                CudaStorageSlice::F16(data)
-            }
-            CpuStorage::F32(storage) => {
-                let data = self.htod_sync_copy(storage)?;
-                CudaStorageSlice::F32(data)
-            }
-            CpuStorage::F64(storage) => {
-                let data = self.htod_sync_copy(storage)?;
+                unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -330,7 +179,8 @@ impl CudaDevice {
             .map_err(|cuda| CudaError::Load {
                 cuda,
                 module_name: module_name.to_string(),
-            })?;
+            })
+            .w()?;
         }
         self.get_func(module_name, module_name)
             // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
@@ -338,6 +188,163 @@ impl CudaDevice {
             .ok_or(CudaError::MissingKernel {
                 module_name: module_name.to_string(),
             })
+            .w()
     }
 }
+
+impl BackendDevice for CudaDevice {
+    type Storage = CudaStorage;
+
+    fn new(ordinal: usize) -> Result<Self> {
+        let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
+        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
+        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
+        Ok(Self {
+            id: DeviceId::new(),
+            device,
+            blas: Arc::new(blas),
+            curand: Arc::new(Mutex::new(CudaRng(curand))),
+        })
+    }
+
+    fn location(&self) -> crate::DeviceLocation {
+        crate::DeviceLocation::Cuda {
+            gpu_id: self.device.ordinal(),
+        }
+    }
+
+    fn same_device(&self, rhs: &Self) -> bool {
+        self.id == rhs.id
+    }
+
+    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+        let elem_count = shape.elem_count();
+        let slice = match dtype {
+            DType::U8 => {
+                let data = self.alloc_zeros::<u8>(elem_count).w()?;
+                CudaStorageSlice::U8(data)
+            }
+            DType::U32 => {
+                let data = self.alloc_zeros::<u32>(elem_count).w()?;
+                CudaStorageSlice::U32(data)
+            }
+            DType::BF16 => {
+                let data = self.alloc_zeros::<bf16>(elem_count).w()?;
+                CudaStorageSlice::BF16(data)
+            }
+            DType::F16 => {
+                let data = self.alloc_zeros::<f16>(elem_count).w()?;
+                CudaStorageSlice::F16(data)
+            }
+            DType::F32 => {
+                let data = self.alloc_zeros::<f32>(elem_count).w()?;
+                CudaStorageSlice::F32(data)
+            }
+            DType::F64 => {
+                let data = self.alloc_zeros::<f64>(elem_count).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
+    fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
+        let elem_count = shape.elem_count();
+        let curand = self.curand.lock().unwrap();
+        let slice = match dtype {
+            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
+                dtype,
+                op: "rand_uniform",
+            })
+            .w()?,
+            DType::F32 => {
+                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+                curand.0.fill_with_uniform(&mut data).w()?;
+                CudaStorageSlice::F32(data)
+            }
+            DType::F64 => {
+                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+                curand.0.fill_with_uniform(&mut data).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        if lo != 0.0 || up != 1.0 {
+            let layout = Layout::contiguous(shape);
+            Affine(up - lo, lo).map(&slice, self, &layout)?;
+        }
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
+    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
+        let elem_count = shape.elem_count();
+        let curand = self.curand.lock().unwrap();
+        let slice = match dtype {
+            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
+                dtype,
+                op: "rand_normal",
+            })
+            .w()?,
+            DType::F32 => {
+                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+                curand
+                    .0
+                    .fill_with_normal(&mut data, mean as f32, std as f32)
+                    .w()?;
+                CudaStorageSlice::F32(data)
+            }
+            DType::F64 => {
+                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+                curand.0.fill_with_normal(&mut data, mean, std).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
+    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+        self.const_impl(1., shape, dtype)
+    }
+
+    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
+        let slice = match storage {
+            CpuStorage::U8(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::U8(data)
+            }
+            CpuStorage::U32(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::U32(data)
+            }
+            CpuStorage::BF16(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::BF16(data)
+            }
+            CpuStorage::F16(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::F16(data)
+            }
+            CpuStorage::F32(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::F32(data)
+            }
+            CpuStorage::F64(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+}
@@ -391,7 +398,7 @@ trait Map2 {
             (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
             (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
             (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
-            _ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
+            _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
         };
         Ok(out)
     }
@@ -405,7 +412,7 @@ impl Map1 for Clone {
         _: &CudaDevice,
         _: &Layout,
     ) -> Result<CudaSlice<T>> {
-        Ok(s.try_clone()?)
+        s.try_clone().w()
     }
 }
 
@@ -426,11 +433,11 @@ impl Map1 for Affine {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el) }?;
+        let out = unsafe { dev.alloc::<T>(el) }.w()?;
         let params = (
             el,
             dims.len(),
             &ds,
             src,
             &out,
             T::from_f64(self.0),
             T::from_f64(self.1),
         );
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }?;
+        unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
     }
 }
@@ -458,14 +465,14 @@ impl Map1 for Elu {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el) }?;
+        let out = unsafe { dev.alloc::<T>(el) }.w()?;
         let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }?;
+        unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
     }
 }
@@ -495,13 +502,15 @@ impl<'a> Map1 for Sum<'a> {
             .map(|&d| src_dims[d + 1..].iter().product::<usize>())
             .collect();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
+        let ds = dev
+            .htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())
+            .w()?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
-        let out = dev.alloc_zeros::<T>(dst_el)?;
+        let out = dev.alloc_zeros::<T>(dst_el).w()?;
         let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }?;
+        unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
     }
 }
@@ -544,13 +553,15 @@ impl<'a> Map1 for FastSum<'a> {
             block_dim: (block_dim as u32, 1, 1),
             shared_mem_bytes: 0,
         };
-        let ds = dev.htod_copy([dims.as_slice(), stride.as_slice()].concat())?;
+        let ds = dev
+            .htod_copy([dims.as_slice(), stride.as_slice()].concat())
+            .w()?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("fast_sum"), kernels::REDUCE)?;
-        let out = dev.alloc_zeros::<T>(dst_el)?;
+        let out = dev.alloc_zeros::<T>(dst_el).w()?;
         let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }?;
+        unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
     }
 }
@@ -566,14 +577,14 @@ impl<U: crate::op::UnaryOp> Map1 for U {
         let dims = shape.dims();
         let el_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el_count) }?;
+        let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
         let params = (el_count, dims.len(), &ds, src, &out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }?;
+        unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
     }
 }
@@ -593,25 +604,27 @@ impl<'a> Map1 for Embedding<'a> {
                 msg: "embedding ids should be u32",
                 expected: DType::U32,
                 got: self.0.dtype(),
-            })?,
+            })
+            .w()?,
         };
         let ids = &ids;
         let shape = ids_l.shape();
         let (v_size, h_size) = rhs_l
             .shape()
             .r2()
-            .map_err(|e| CudaError::WrappedError(Box::new(e)))?;
+            .map_err(|e| CudaError::WrappedError(Box::new(e)))
+            .w()?;
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, ids_l.stride()].concat())?;
+        let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
         let rhs = &rhs.slice(rhs_l.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el * h_size) }?;
+        let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
         let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }?;
+        unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
     }
 }
@@ -640,7 +653,7 @@ impl<'a> Map2 for Conv1D<'a> {
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(dst_el) }?;
+        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
         let ds = if dims.len() == 3 {
             [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
         } else if dims.len() == 2 {
@@ -648,10 +661,10 @@ impl<'a> Map2 for Conv1D<'a> {
         } else {
             panic!("unexpected input shape for conv1d {dims:?}")
         };
-        let ds = dev.htod_copy(ds)?;
+        let ds = dev.htod_copy(ds).w()?;
         let params = (el, l_out, p.stride, &ds, inp, k, &out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }?;
+        unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
     }
 }
@@ -673,23 +686,25 @@ impl<'a> Map2 for WhereCond<'a> {
                 msg: "where conditions should be u32",
                 expected: DType::U32,
                 got: self.0.dtype(),
-            })?,
+            })
+            .w()?,
         };
         let ids = &ids;
         let shape = ids_l.shape();
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds =
-            dev.htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
+        let ds = dev
+            .htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())
+            .w()?;
         let t = &t.slice(layout_t.start_offset()..);
         let f = &f.slice(layout_f.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el) }?;
+        let out = unsafe { dev.alloc::<T>(el) }.w()?;
         let params = (el, dims.len(), &ds, ids, t, f, &out);
         // SAFETY: ffi
-        unsafe { func.launch(cfg, params) }?;
+        unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
     }
 }
@@ -707,15 +722,17 @@ impl<U: crate::op::BinaryOp> Map2 for U {
         let dims = shape.dims();
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
+        let dims_and_strides = dev
+            .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
+            .w()?;
         let lhs = &lhs.slice(lhs_l.start_offset()..);
         let rhs = &rhs.slice(rhs_l.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(elem_count) }?;
+        let out = unsafe { dev.alloc::<T>(elem_count) }.w()?;
         let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
         // SAFETY: ffi
-        unsafe { func.launch(cfg, params) }?;
+        unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
     }
 }
@@ -771,7 +788,8 @@ fn gemm_config<T>(
             lhs_stride: lhs_stride.to_vec(),
             rhs_stride: rhs_stride.to_vec(),
             mnk: (m, n, k),
-        })?
+        })
+        .w()?
     };
     // The b tensor has dims batching, m, k (lhs)
     let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
@@ -783,7 +801,8 @@ fn gemm_config<T>(
             lhs_stride: lhs_stride.to_vec(),
             rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
-        })?
+        })
+        .w()?
    };
    // The setup below was copied from:
    // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
@@ -808,7 +827,8 @@ fn gemm_config<T>(
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
-        })?,
+        })
+        .w()?,
    };
    let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
        [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
@@ -818,7 +838,8 @@ fn gemm_config<T>(
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
-        })?,
+        })
+        .w()?,
    };
 
    Ok(StridedBatchedConfig {
@@ -830,14 +851,16 @@ fn gemm_config<T>(
     })
 }
 
-impl CudaStorage {
-    pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
+impl BackendStorage for CudaStorage {
+    type Device = CudaDevice;
+
+    fn try_clone(&self, layout: &Layout) -> Result<Self> {
         let slice = Clone.map(&self.slice, self.device(), layout)?;
         let device = self.device.clone();
         Ok(Self { slice, device })
     }
 
-    pub fn dtype(&self) -> DType {
+    fn dtype(&self) -> DType {
         match self.slice {
             CudaStorageSlice::U8(_) => DType::U8,
             CudaStorageSlice::U32(_) => DType::U32,
@@ -848,18 +871,18 @@ impl CudaStorage {
         }
     }
 
-    pub fn device(&self) -> &CudaDevice {
+    fn device(&self) -> &CudaDevice {
         &self.device
     }
 
-    pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
+    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
         use cudarc::driver::DevicePtr;
         let shape = layout.shape();
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let dev = self.device();
-        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
         let start_o = layout.start_offset();
         // This returns an i64 rather than a &i64, this is useful to get around some temporary
         // lifetime issue and is safe as long as self.slice does not go out of scope before inp
@@ -878,39 +901,39 @@ impl CudaStorage {
         let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
         let slice = match dtype {
             DType::U8 => {
-                let out = unsafe { dev.alloc::<u8>(el) }?;
+                let out = unsafe { dev.alloc::<u8>(el) }.w()?;
                 let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::U8(out)
             }
             DType::U32 => {
-                let out = unsafe { dev.alloc::<u32>(el) }?;
+                let out = unsafe { dev.alloc::<u32>(el) }.w()?;
                 let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::U32(out)
             }
             DType::BF16 => {
-                let out = unsafe { dev.alloc::<bf16>(el) }?;
+                let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
                 let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::BF16(out)
             }
             DType::F16 => {
-                let out = unsafe { dev.alloc::<f16>(el) }?;
+                let out = unsafe { dev.alloc::<f16>(el) }.w()?;
                 let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::F16(out)
             }
             DType::F32 => {
-                let out = unsafe { dev.alloc::<f32>(el) }?;
+                let out = unsafe { dev.alloc::<f32>(el) }.w()?;
                 let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::F32(out)
             }
             DType::F64 => {
-                let out = unsafe { dev.alloc::<f64>(el) }?;
+                let out = unsafe { dev.alloc::<f64>(el) }.w()?;
                 let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::F64(out)
             }
         };
@@ -920,37 +943,35 @@ impl CudaStorage {
         })
     }
 
-    pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
+    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
         let device = self.device().clone();
         let slice = Affine(mul, add).map(&self.slice, &device, layout)?;
         Ok(Self { slice, device })
     }
 
-    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
+    fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
         let device = self.device().clone();
         let slice = Elu(alpha).map(&self.slice, &device, layout)?;
         Ok(Self { slice, device })
     }
 
-    pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+    fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         let device = self.device().clone();
         let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
         Ok(Self { slice, device })
     }
 
-    pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
-        Err(CudaError::InternalError(
-            "TODO: implement divide_by_sum_over_dim",
-        ))
+    fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
+        Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
     }
 
-    pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
+    fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
         let device = self.device().clone();
         let slice = U::V.map(&self.slice, &device, layout)?;
         Ok(Self { slice, device })
     }
 
-    pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
+    fn binary_impl<B: crate::op::BinaryOp>(
         &self,
         rhs: &Self,
         lhs_l: &Layout,
@@ -961,42 +982,42 @@ impl CudaStorage {
         Ok(Self { slice, device })
     }
 
-    pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
+    fn to_cpu_storage(&self) -> Result<CpuStorage> {
         match &self.slice {
             CudaStorageSlice::U8(slice) => {
                 let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice)?;
+                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                 Ok(CpuStorage::U8(cpu_storage))
             }
             CudaStorageSlice::U32(slice) => {
                 let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice)?;
+                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                 Ok(CpuStorage::U32(cpu_storage))
             }
             CudaStorageSlice::BF16(slice) => {
                 let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice)?;
+                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                 Ok(CpuStorage::BF16(cpu_storage))
             }
             CudaStorageSlice::F16(slice) => {
                 let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice)?;
+                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                 Ok(CpuStorage::F16(cpu_storage))
             }
             CudaStorageSlice::F32(slice) => {
                 let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice)?;
+                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                 Ok(CpuStorage::F32(cpu_storage))
             }
             CudaStorageSlice::F64(slice) => {
                 let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice)?;
+                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                 Ok(CpuStorage::F64(cpu_storage))
             }
         }
     }
 
-    pub(crate) fn where_cond(
+    fn where_cond(
         &self,
         layout: &Layout,
         t: &Self,
@@ -1009,7 +1030,7 @@ impl CudaStorage {
         Ok(Self { slice, device })
     }
 
-    pub(crate) fn conv1d(
+    fn conv1d(
         &self,
         l: &Layout,
         kernel: &Self,
@@ -1021,13 +1042,13 @@ impl CudaStorage {
         Ok(Self { slice, device })
     }
 
-    pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
+    fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let device = self.device().clone();
         let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
         Ok(Self { slice, device })
     }
 
-    pub(crate) fn matmul(
+    fn matmul(
         &self,
         rhs: &Self,
         (b, m, n, k): (usize, usize, usize, usize),
@@ -1041,146 +1062,144 @@ impl CudaStorage {
                 let lhs = &lhs.slice(lhs_l.start_offset()..);
                 let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
-                let mut out = unsafe { dev.alloc::<bf16>(elem_count) }?;
+                let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
                 unsafe {
                     self.device
                         .blas
                         .gemm_strided_batched(cfg, rhs, lhs, &mut out)
-                }?;
+                }
+                .w()?;
                 CudaStorageSlice::BF16(out)
             }
             (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
                 let lhs = &lhs.slice(lhs_l.start_offset()..);
                 let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
-                let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
+                let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
                 unsafe {
                     self.device
                         .blas
                         .gemm_strided_batched(cfg, rhs, lhs, &mut out)
-                }?;
+                }
+                .w()?;
                 CudaStorageSlice::F16(out)
             }
             (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
                 let lhs = &lhs.slice(lhs_l.start_offset()..);
                 let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
-                let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
+                let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
                 unsafe {
                     self.device
                         .blas
                         .gemm_strided_batched(cfg, rhs, lhs, &mut out)
-                }?;
+                }
+                .w()?;
                 CudaStorageSlice::F32(out)
             }
             (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
                 let lhs = &lhs.slice(lhs_l.start_offset()..);
                 let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
-                let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
+                let mut out = unsafe { dev.alloc::<f64>(elem_count) }.w()?;
                 unsafe {
                     self.device
                         .blas
                         .gemm_strided_batched(cfg, rhs, lhs, &mut out)
-                }?;
+                }
+                .w()?;
                 CudaStorageSlice::F64(out)
             }
-            _ => return Err(CudaError::InternalError("dtype mismatch in matmul op")),
+            _ => Err(CudaError::InternalError("dtype mismatch in matmul op")).w()?,
         };
         let device = dev.clone();
         Ok(Self { slice, device })
     }
 
-    pub(crate) fn copy_strided_src(
-        &self,
-        dst: &mut Self,
-        dst_offset: usize,
-        src_l: &Layout,
-    ) -> Result<()> {
+    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
         let src_shape = src_l.shape();
         let dims = src_shape.dims();
         let el_count = src_shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let dev = &self.device;
-        let ds = dev.htod_copy([dims, src_l.stride()].concat())?;
+        let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
         match (&self.slice, &mut dst.slice) {
             (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst)?
+                    dev.dtod_copy(&src, &mut dst).w()?
                 } else {
                     let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
                     // SAFETY: Set later by running the kernel.
                     let params = (el_count, dims.len(), &ds, &src, &mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }?
+                    unsafe { func.launch(cfg, params) }.w()?
                 }
             }
             (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst)?
+                    dev.dtod_copy(&src, &mut dst).w()?
                 } else {
                     let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
                     // SAFETY: Set later by running the kernel.
                     let params = (el_count, dims.len(), &ds, &src, &mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }?
+                    unsafe { func.launch(cfg, params) }.w()?
                 }
             }
             (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst)?
+                    dev.dtod_copy(&src, &mut dst).w()?
                 } else {
                     let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
                     // SAFETY: Set later by running the kernel.
                     let params = (el_count, dims.len(), &ds, &src, &mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }?
+                    unsafe { func.launch(cfg, params) }.w()?
                 }
             }
             (CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst)?
+                    dev.dtod_copy(&src, &mut dst).w()?
                 } else {
                     let func = dev.get_or_load_func("ucopy_u8", kernels::UNARY)?;
                     // SAFETY: Set later by running the kernel.
                     let params = (el_count, dims.len(), &ds, &src, &mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }?
+                    unsafe { func.launch(cfg, params) }.w()?
                 }
             }
             (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst)?
+                    dev.dtod_copy(&src, &mut dst).w()?
                 } else {
                     let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
                     // SAFETY: Set later by running the kernel.
                     let params = (el_count, dims.len(), &ds, &src, &mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }?
+                    unsafe { func.launch(cfg, params) }.w()?
                 }
             }
             (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst)?
+                    dev.dtod_copy(&src, &mut dst).w()?
                 } else {
                     let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
                     // SAFETY: Set later by running the kernel.
                     let params = (el_count, dims.len(), &ds, &src, &mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }?;
+                    unsafe { func.launch(cfg, params) }.w()?;
                 }
             }
-            _ => {
-                return Err(CudaError::InternalError(
-                    "dtype mismatch in copy_strided op",
-                ))
-            }
+            _ => Err(CudaError::InternalError(
+                "dtype mismatch in copy_strided op",
+            ))
+            .w()?,
         }
         Ok(())
     }
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 1380cbc9..b428922b 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -1,3 +1,4 @@
+use crate::backend::BackendDevice;
 use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
 
 /// A `DeviceLocation` represents a physical device whereas multiple `Device`
@@ -85,10 +86,10 @@ impl Device {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
     }
 
-    pub fn same_id(&self, rhs: &Self) -> bool {
+    pub fn same_device(&self, rhs: &Self) -> bool {
         match (self, rhs) {
             (Self::Cpu, Self::Cpu) => true,
-            (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_id(rhs),
+            (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
             _ => false,
         }
     }
@@ -96,9 +97,7 @@ impl Device {
     pub fn location(&self) -> DeviceLocation {
         match self {
             Self::Cpu => DeviceLocation::Cpu,
-            Self::Cuda(device) => DeviceLocation::Cuda {
-                gpu_id: device.ordinal(),
-            },
+            Self::Cuda(device) => device.location(),
         }
     }
 
@@ -178,7 +177,7 @@ impl Device {
             Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
             Device::Cuda(device) => {
                 let storage = array.to_cpu_storage();
-                let storage = device.cuda_from_cpu_storage(&storage)?;
+                let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
         }
@@ -189,7 +188,7 @@ impl Device {
             Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
             Device::Cuda(device) => {
                 let storage = S::to_cpu_storage_owned(data);
-                let storage = device.cuda_from_cpu_storage(&storage)?;
+                let storage = device.storage_from_cpu_storage(&storage)?;
                 Ok(Storage::Cuda(storage))
             }
         }
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index f5c80fcf..a81dda57 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -1,98 +1,62 @@
 #![allow(dead_code)]
 use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
 
-#[derive(thiserror::Error, Debug)]
-pub enum DummyError {}
-pub type CudaError = DummyError;
-
 #[derive(Debug, Clone)]
 pub struct CudaDevice;
 
+#[derive(Debug)]
+pub struct CudaStorage;
+
 macro_rules! fail {
     () => {
         unimplemented!("cuda support has not been enabled")
     };
 }
 
-impl CudaDevice {
-    pub(crate) fn new(_: usize) -> Result<Self> {
+impl crate::backend::BackendStorage for CudaStorage {
+    type Device = CudaDevice;
+
+    fn try_clone(&self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn same_id(&self, _: &Self) -> bool {
-        true
-    }
-
-    pub(crate) fn ordinal(&self) -> usize {
+    fn dtype(&self) -> DType {
         fail!()
     }
 
-    pub(crate) fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub(crate) fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub(crate) fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<CudaStorage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub(crate) fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<CudaStorage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-}
-
-#[derive(Debug)]
-pub struct CudaStorage;
-
-impl CudaStorage {
-    pub fn try_clone(&self, _: &Layout) -> Result<Self> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub fn dtype(&self) -> DType {
+    fn device(&self) -> &Self::Device {
         fail!()
     }
 
-    pub fn device(&self) -> &CudaDevice {
-        fail!()
-    }
-
-    pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
+    fn to_cpu_storage(&self) -> Result<CpuStorage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
+    fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
+    fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
    }
 
-    pub(crate) fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
+    fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
+    fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
+    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
+    fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
+    fn binary_impl<B: crate::op::BinaryOp>(
         &self,
         _: &Self,
         _: &Layout,
@@ -101,32 +65,25 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn where_cond(
+    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn conv1d(
         &self,
         _: &Layout,
         _: &Self,
         _: &Layout,
-        _: &Self,
-        _: &Layout,
+        _: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn conv1d(
-        &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &crate::conv::ParamsConv1D,
-    ) -> Result<Self> {
+    fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
-    pub(crate) fn matmul(
+    fn matmul(
         &self,
         _: &Self,
         _: (usize, usize, usize, usize),
@@ -136,7 +93,42 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
+    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+}
+
+impl crate::backend::BackendDevice for CudaDevice {
+    type Storage = CudaStorage;
+    fn new(_: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn location(&self) -> crate::DeviceLocation {
+        fail!()
+    }
+
+    fn same_device(&self, _: &Self) -> bool {
+        fail!()
+    }
+
+    fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 }
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index d8f3b4b4..caad3e1f 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -100,7 +100,7 @@ pub enum Error {
     },
 
     #[error(transparent)]
-    Cuda(#[from] crate::CudaError),
+    Cuda(Box<dyn std::error::Error + Send + Sync>),
 
     #[error(transparent)]
     TryFromIntError(#[from] core::num::TryFromIntError),
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index d36f90af..06fc87d1 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -33,6 +33,7 @@
 //!
 //! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
 
+mod backend;
 mod backprop;
 mod conv;
 mod cpu_backend;
@@ -68,10 +69,10 @@ use strided_index::StridedIndex;
 pub use tensor::{Tensor, TensorId};
 
 #[cfg(feature = "cuda")]
-pub use cuda_backend::{CudaDevice, CudaError, CudaStorage};
+pub use cuda_backend::{CudaDevice, CudaStorage};
 
 #[cfg(not(feature = "cuda"))]
-pub use dummy_cuda_backend::{CudaDevice, CudaError, CudaStorage};
+pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
 
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index ee12eeb8..5f92172d 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -1,3 +1,4 @@
+use crate::backend::BackendStorage;
 use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
 
 // We do not want to implement Clone on Storage as cloning may fail because of
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ecc018f9..5d4e106f 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,3 +1,4 @@
+use crate::backend::{BackendDevice, BackendStorage};
 use crate::shape::Dim;
 use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::Arc;
@@ -963,19 +964,19 @@ impl Tensor {
 
     /// If the target device is the same as the tensor device, only a shallow copy is performed.
     pub fn to_device(&self, device: &Device) -> Result<Tensor> {
-        if self.device().same_id(device) {
+        if self.device().same_device(device) {
             Ok(self.clone())
         } else {
             let storage = match (self.storage.as_ref(), device) {
                 (Storage::Cpu(storage), Device::Cuda(cuda)) => {
-                    Storage::Cuda(cuda.cuda_from_cpu_storage(storage)?)
+                    Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
                 }
                 (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
                 (Storage::Cuda(storage), Device::Cuda(cuda)) => {
                     // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
                     // are the same.
                     let cpu_storage = storage.to_cpu_storage()?;
-                    Storage::Cuda(cuda.cuda_from_cpu_storage(&cpu_storage)?)
+                    Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
                 }
                 (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
             };
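
--
Reviewer note (not part of the patch): a minimal sketch of what the new
`BackendDevice`/`BackendStorage` traits buy inside the crate. The `ones_like`
helper below is hypothetical and added only for illustration; it assumes the
`pub(crate)` traits introduced in candle-core/src/backend.rs are in scope.

    use crate::backend::BackendDevice;
    use crate::{DType, Result, Shape};

    // Generic over any backend device: CUDA builds resolve `D` to the real
    // cuda_backend::CudaDevice, non-CUDA builds to the dummy backend, and the
    // same code compiles against either through the shared trait surface.
    fn ones_like<D: BackendDevice>(dev: &D, shape: &Shape, dtype: DType) -> Result<D::Storage> {
        dev.ones_impl(shape, dtype)
    }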