From 2bfe8f18ab43b92a426d1f4b61e9c2c3b16c7050 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 21 Jun 2023 18:11:56 +0100 Subject: [PATCH] Start adding support for cuda. --- Cargo.toml | 2 +- src/device.rs | 32 +++++++++++++++++++++++++------- src/error.rs | 9 ++++++--- src/lib.rs | 2 +- src/storage.rs | 12 ++++++------ 5 files changed, 39 insertions(+), 18 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 44d64f1a..5e2508b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,5 +22,5 @@ rand = "0.8.5" tokenizers = "0.13.3" [features] -default = [] +default = ["cuda"] cuda = ["dep:cudarc"] diff --git a/src/device.rs b/src/device.rs index c76cc301..bb3d8870 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1,11 +1,19 @@ use crate::{CpuStorage, DType, Result, Shape, Storage}; +/// A `DeviceLocation` represents a physical device whereas multiple `Device` +/// can live on the same location (typically for cuda devices). #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum Device { +pub enum DeviceLocation { Cpu, Cuda { gpu_id: usize }, } +#[derive(Debug, Clone)] +pub enum Device { + Cpu, + Cuda(std::sync::Arc), +} + // TODO: Should we back the cpu implementation using the NdArray crate or similar? pub trait NdArray { fn shape(&self) -> Result; @@ -54,13 +62,22 @@ impl NdArray for &[[S; N]; } impl Device { + pub fn location(&self) -> DeviceLocation { + match self { + Self::Cpu => DeviceLocation::Cpu, + Self::Cuda(device) => DeviceLocation::Cuda { + gpu_id: device.ordinal(), + }, + } + } + pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result { match self { Device::Cpu => { let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype)); Ok(storage) } - Device::Cuda { gpu_id: _ } => { + Device::Cuda(_) => { todo!() } } @@ -69,11 +86,12 @@ impl Device { pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result { match self { Device::Cpu => { - let storage = Storage::Cpu(CpuStorage::zeros_impl(shape, dtype)); - Ok(storage) + let storage = CpuStorage::zeros_impl(shape, dtype); + Ok(Storage::Cpu(storage)) } - Device::Cuda { gpu_id: _ } => { - todo!() + Device::Cuda(device) => { + let storage = device.alloc_zeros::(shape.elem_count())?; + Ok(Storage::Cuda(storage)) } } } @@ -84,7 +102,7 @@ impl Device { let storage = Storage::Cpu(array.to_cpu_storage()); Ok(storage) } - Device::Cuda { gpu_id: _ } => { + Device::Cuda(_) => { todo!() } } diff --git a/src/error.rs b/src/error.rs index 0114a86c..1990b67a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,4 @@ -use crate::{DType, Device, Shape}; +use crate::{DType, DeviceLocation, Shape}; /// Main library error type. #[derive(thiserror::Error, Debug)] @@ -15,8 +15,8 @@ pub enum Error { #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] DeviceMismatchBinaryOp { - lhs: Device, - rhs: Device, + lhs: DeviceLocation, + rhs: DeviceLocation, op: &'static str, }, @@ -33,6 +33,9 @@ pub enum Error { got: usize, shape: Shape, }, + + #[error(transparent)] + Cudarc(#[from] cudarc::driver::DriverError), } pub type Result = std::result::Result; diff --git a/src/lib.rs b/src/lib.rs index 175d36ad..9462bfa0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,7 @@ mod strided_index; mod tensor; pub use cpu_backend::CpuStorage; -pub use device::Device; +pub use device::{Device, DeviceLocation}; pub use dtype::{DType, WithDType}; pub use error::{Error, Result}; pub use shape::Shape; diff --git a/src/storage.rs b/src/storage.rs index 7083cc28..4c8ebcc3 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -3,7 +3,7 @@ use crate::{CpuStorage, DType, Device, Error, Result, Shape}; #[derive(Debug, Clone)] pub enum Storage { Cpu(CpuStorage), - Cuda { gpu_id: usize }, // TODO: Actually add the storage. + Cuda(cudarc::driver::CudaSlice), } pub(crate) trait UnaryOp { @@ -100,7 +100,7 @@ impl Storage { pub fn device(&self) -> Device { match self { Self::Cpu(_) => Device::Cpu, - Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id }, + Self::Cuda(slice) => Device::Cuda(slice.device()), } } @@ -112,8 +112,8 @@ impl Storage { } pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> { - let lhs = self.device(); - let rhs = rhs.device(); + let lhs = self.device().location(); + let rhs = rhs.device().location(); if lhs != rhs { Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }) } else { @@ -179,8 +179,8 @@ impl Storage { // Should not happen because of the same device check above but we're defensive // anyway. Err(Error::DeviceMismatchBinaryOp { - lhs: lhs.device(), - rhs: rhs.device(), + lhs: lhs.device().location(), + rhs: rhs.device().location(), op: B::NAME, }) }