Add some skeleton code for GPU support.

This commit is contained in:
laurent
2023-06-21 09:13:57 +01:00
parent f319583530
commit 8cde0c5478
3 changed files with 28 additions and 0 deletions

View File

@ -6,6 +6,7 @@ use crate::{
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Device { pub enum Device {
Cpu, Cpu,
Cuda { gpu_id: usize },
} }
// TODO: Should we back the cpu implementation using the NdArray crate or similar? // TODO: Should we back the cpu implementation using the NdArray crate or similar?
@ -72,6 +73,9 @@ impl Device {
}; };
Storage::Cpu(storage) Storage::Cpu(storage)
} }
Device::Cuda { gpu_id: _ } => {
todo!()
}
} }
} }
@ -91,12 +95,18 @@ impl Device {
}; };
Storage::Cpu(storage) Storage::Cpu(storage)
} }
Device::Cuda { gpu_id: _ } => {
todo!()
}
} }
} }
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage { pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
match self { match self {
Device::Cpu => Storage::Cpu(array.to_cpu_storage()), Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
Device::Cuda { gpu_id: _ } => {
todo!()
}
} }
} }
} }

View File

@ -20,6 +20,7 @@ impl CpuStorage {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Storage { pub enum Storage {
Cpu(CpuStorage), Cpu(CpuStorage),
Cuda { gpu_id: usize }, // TODO: Actually add the storage.
} }
trait UnaryOp { trait UnaryOp {
@ -116,12 +117,14 @@ impl Storage {
pub fn device(&self) -> Device { pub fn device(&self) -> Device {
match self { match self {
Self::Cpu(_) => Device::Cpu, Self::Cpu(_) => Device::Cpu,
Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id },
} }
} }
pub fn dtype(&self) -> DType { pub fn dtype(&self) -> DType {
match self { match self {
Self::Cpu(storage) => storage.dtype(), Self::Cpu(storage) => storage.dtype(),
Self::Cuda { .. } => todo!(),
} }
} }
@ -168,6 +171,7 @@ impl Storage {
Ok(Storage::Cpu(CpuStorage::F64(data))) Ok(Storage::Cpu(CpuStorage::F64(data)))
} }
}, },
Self::Cuda { .. } => todo!(),
} }
} }
@ -186,6 +190,7 @@ impl Storage {
Ok(Storage::Cpu(CpuStorage::F64(data))) Ok(Storage::Cpu(CpuStorage::F64(data)))
} }
}, },
Self::Cuda { .. } => todo!(),
} }
} }
@ -232,6 +237,16 @@ impl Storage {
}) })
} }
}, },
(Self::Cuda { .. }, Self::Cuda { .. }) => todo!(),
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device(),
rhs: rhs.device(),
op: B::NAME,
})
}
} }
} }

View File

@ -209,6 +209,7 @@ impl Tensor {
let data = S::cpu_storage_as_slice(cpu_storage)?; let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok(data[0]) Ok(data[0])
} }
Storage::Cuda { .. } => todo!(),
} }
} }
@ -249,6 +250,7 @@ impl Tensor {
let data = S::cpu_storage_as_slice(cpu_storage)?; let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok(self.strided_index().map(|i| data[i]).collect()) Ok(self.strided_index().map(|i| data[i]).collect())
} }
Storage::Cuda { .. } => todo!(),
} }
} }
@ -266,6 +268,7 @@ impl Tensor {
assert!(src_index.next().is_none()); assert!(src_index.next().is_none());
Ok(rows) Ok(rows)
} }
Storage::Cuda { .. } => todo!(),
} }
} }