use crate::{CpuStorage, DType, Result, Shape, Storage}; /// A `DeviceLocation` represents a physical device whereas multiple `Device` /// can live on the same location (typically for cuda devices). #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DeviceLocation { Cpu, Cuda { gpu_id: usize }, } #[derive(Debug, Clone)] pub enum Device { Cpu, Cuda(crate::CudaDevice), } // TODO: Should we back the cpu implementation using the NdArray crate or similar? pub trait NdArray { fn shape(&self) -> Result; fn to_cpu_storage(&self) -> CpuStorage; } impl NdArray for S { fn shape(&self) -> Result { Ok(Shape::from(())) } fn to_cpu_storage(&self) -> CpuStorage { S::to_cpu_storage(&[*self]) } } impl NdArray for &[S; N] { fn shape(&self) -> Result { Ok(Shape::from(self.len())) } fn to_cpu_storage(&self) -> CpuStorage { S::to_cpu_storage(self.as_slice()) } } impl NdArray for &[S] { fn shape(&self) -> Result { Ok(Shape::from(self.len())) } fn to_cpu_storage(&self) -> CpuStorage { S::to_cpu_storage(self) } } impl NdArray for &[[S; N]; M] { fn shape(&self) -> Result { Ok(Shape::from((M, N))) } fn to_cpu_storage(&self) -> CpuStorage { S::to_cpu_storage_owned(self.concat()) } } impl Device { pub fn new_cuda(ordinal: usize) -> Result { Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?)) } pub fn location(&self) -> DeviceLocation { match self { Self::Cpu => DeviceLocation::Cpu, Self::Cuda(device) => DeviceLocation::Cuda { gpu_id: device.ordinal(), }, } } pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result { match self { Device::Cpu => { let storage = CpuStorage::ones_impl(shape, dtype); Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { let storage = device.ones_impl(shape, dtype)?; Ok(Storage::Cuda(storage)) } } } pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result { match self { Device::Cpu => { let storage = CpuStorage::zeros_impl(shape, dtype); Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { let storage = device.zeros_impl(shape, dtype)?; Ok(Storage::Cuda(storage)) } } } pub(crate) fn tensor(&self, array: A) -> Result { match self { Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())), Device::Cuda(device) => { let storage = array.to_cpu_storage(); let storage = device.cuda_from_cpu_storage(&storage)?; Ok(Storage::Cuda(storage)) } } } }