Start adding support for cuda.

This commit is contained in:
laurent
2023-06-21 18:11:56 +01:00
parent 7c317f9611
commit 2bfe8f18ab
5 changed files with 39 additions and 18 deletions

View File

@ -3,7 +3,7 @@ use crate::{CpuStorage, DType, Device, Error, Result, Shape};
#[derive(Debug, Clone)]
pub enum Storage {
Cpu(CpuStorage),
Cuda { gpu_id: usize }, // TODO: Actually add the storage.
Cuda(cudarc::driver::CudaSlice<f32>),
}
pub(crate) trait UnaryOp {
@ -100,7 +100,7 @@ impl Storage {
pub fn device(&self) -> Device {
match self {
Self::Cpu(_) => Device::Cpu,
Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id },
Self::Cuda(slice) => Device::Cuda(slice.device()),
}
}
@ -112,8 +112,8 @@ impl Storage {
}
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
let lhs = self.device();
let rhs = rhs.device();
let lhs = self.device().location();
let rhs = rhs.device().location();
if lhs != rhs {
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
} else {
@ -179,8 +179,8 @@ impl Storage {
// Should not happen because of the same device check above but we're defensive
// anyway.
Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device(),
rhs: rhs.device(),
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: B::NAME,
})
}