From 8cde0c54788d7ae7c676e4f2fad5fcbc16f6980c Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 21 Jun 2023 09:13:57 +0100 Subject: [PATCH] Add some skeleton code for GPU support. --- src/device.rs | 10 ++++++++++ src/storage.rs | 15 +++++++++++++++ src/tensor.rs | 3 +++ 3 files changed, 28 insertions(+) diff --git a/src/device.rs b/src/device.rs index d7b724d1..c092a347 100644 --- a/src/device.rs +++ b/src/device.rs @@ -6,6 +6,7 @@ use crate::{ #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum Device { Cpu, + Cuda { gpu_id: usize }, } // TODO: Should we back the cpu implementation using the NdArray crate or similar? @@ -72,6 +73,9 @@ impl Device { }; Storage::Cpu(storage) } + Device::Cuda { gpu_id: _ } => { + todo!() + } } } @@ -91,12 +95,18 @@ impl Device { }; Storage::Cpu(storage) } + Device::Cuda { gpu_id: _ } => { + todo!() + } } } pub(crate) fn tensor(&self, array: A) -> Storage { match self { Device::Cpu => Storage::Cpu(array.to_cpu_storage()), + Device::Cuda { gpu_id: _ } => { + todo!() + } } } } diff --git a/src/storage.rs b/src/storage.rs index 463788d4..30161a2c 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -20,6 +20,7 @@ impl CpuStorage { #[derive(Debug, Clone)] pub enum Storage { Cpu(CpuStorage), + Cuda { gpu_id: usize }, // TODO: Actually add the storage. } trait UnaryOp { @@ -116,12 +117,14 @@ impl Storage { pub fn device(&self) -> Device { match self { Self::Cpu(_) => Device::Cpu, + Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id }, } } pub fn dtype(&self) -> DType { match self { Self::Cpu(storage) => storage.dtype(), + Self::Cuda { .. } => todo!(), } } @@ -168,6 +171,7 @@ impl Storage { Ok(Storage::Cpu(CpuStorage::F64(data))) } }, + Self::Cuda { .. } => todo!(), } } @@ -186,6 +190,7 @@ impl Storage { Ok(Storage::Cpu(CpuStorage::F64(data))) } }, + Self::Cuda { .. } => todo!(), } } @@ -232,6 +237,16 @@ impl Storage { }) } }, + (Self::Cuda { .. }, Self::Cuda { .. }) => todo!(), + (lhs, rhs) => { + // Should not happen because of the same device check above but we're defensive + // anyway. + Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device(), + rhs: rhs.device(), + op: B::NAME, + }) + } } } diff --git a/src/tensor.rs b/src/tensor.rs index b8fa738a..2d704a65 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -209,6 +209,7 @@ impl Tensor { let data = S::cpu_storage_as_slice(cpu_storage)?; Ok(data[0]) } + Storage::Cuda { .. } => todo!(), } } @@ -249,6 +250,7 @@ impl Tensor { let data = S::cpu_storage_as_slice(cpu_storage)?; Ok(self.strided_index().map(|i| data[i]).collect()) } + Storage::Cuda { .. } => todo!(), } } @@ -266,6 +268,7 @@ impl Tensor { assert!(src_index.next().is_none()); Ok(rows) } + Storage::Cuda { .. } => todo!(), } }