Start adding support for cuda.

This commit is contained in:
laurent
2023-06-21 18:11:56 +01:00
parent 7c317f9611
commit 2bfe8f18ab
5 changed files with 39 additions and 18 deletions

View File

@ -22,5 +22,5 @@ rand = "0.8.5"
tokenizers = "0.13.3" tokenizers = "0.13.3"
[features] [features]
default = [] default = ["cuda"]
cuda = ["dep:cudarc"] cuda = ["dep:cudarc"]

View File

@ -1,11 +1,19 @@
use crate::{CpuStorage, DType, Result, Shape, Storage}; use crate::{CpuStorage, DType, Result, Shape, Storage};
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
/// can live on the same location (typically for cuda devices).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Device { pub enum DeviceLocation {
Cpu, Cpu,
Cuda { gpu_id: usize }, Cuda { gpu_id: usize },
} }
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(std::sync::Arc<cudarc::driver::CudaDevice>),
}
// TODO: Should we back the cpu implementation using the NdArray crate or similar? // TODO: Should we back the cpu implementation using the NdArray crate or similar?
pub trait NdArray { pub trait NdArray {
fn shape(&self) -> Result<Shape>; fn shape(&self) -> Result<Shape>;
@ -54,13 +62,22 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
} }
impl Device { impl Device {
pub fn location(&self) -> DeviceLocation {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => DeviceLocation::Cuda {
gpu_id: device.ordinal(),
},
}
}
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> { pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self { match self {
Device::Cpu => { Device::Cpu => {
let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype)); let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype));
Ok(storage) Ok(storage)
} }
Device::Cuda { gpu_id: _ } => { Device::Cuda(_) => {
todo!() todo!()
} }
} }
@ -69,11 +86,12 @@ impl Device {
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> { pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self { match self {
Device::Cpu => { Device::Cpu => {
let storage = Storage::Cpu(CpuStorage::zeros_impl(shape, dtype)); let storage = CpuStorage::zeros_impl(shape, dtype);
Ok(storage) Ok(Storage::Cpu(storage))
} }
Device::Cuda { gpu_id: _ } => { Device::Cuda(device) => {
todo!() let storage = device.alloc_zeros::<f32>(shape.elem_count())?;
Ok(Storage::Cuda(storage))
} }
} }
} }
@ -84,7 +102,7 @@ impl Device {
let storage = Storage::Cpu(array.to_cpu_storage()); let storage = Storage::Cpu(array.to_cpu_storage());
Ok(storage) Ok(storage)
} }
Device::Cuda { gpu_id: _ } => { Device::Cuda(_) => {
todo!() todo!()
} }
} }

View File

@ -1,4 +1,4 @@
use crate::{DType, Device, Shape}; use crate::{DType, DeviceLocation, Shape};
/// Main library error type. /// Main library error type.
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
@ -15,8 +15,8 @@ pub enum Error {
#[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
DeviceMismatchBinaryOp { DeviceMismatchBinaryOp {
lhs: Device, lhs: DeviceLocation,
rhs: Device, rhs: DeviceLocation,
op: &'static str, op: &'static str,
}, },
@ -33,6 +33,9 @@ pub enum Error {
got: usize, got: usize,
shape: Shape, shape: Shape,
}, },
#[error(transparent)]
Cudarc(#[from] cudarc::driver::DriverError),
} }
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;

View File

@ -9,7 +9,7 @@ mod strided_index;
mod tensor; mod tensor;
pub use cpu_backend::CpuStorage; pub use cpu_backend::CpuStorage;
pub use device::Device; pub use device::{Device, DeviceLocation};
pub use dtype::{DType, WithDType}; pub use dtype::{DType, WithDType};
pub use error::{Error, Result}; pub use error::{Error, Result};
pub use shape::Shape; pub use shape::Shape;

View File

@ -3,7 +3,7 @@ use crate::{CpuStorage, DType, Device, Error, Result, Shape};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Storage { pub enum Storage {
Cpu(CpuStorage), Cpu(CpuStorage),
Cuda { gpu_id: usize }, // TODO: Actually add the storage. Cuda(cudarc::driver::CudaSlice<f32>),
} }
pub(crate) trait UnaryOp { pub(crate) trait UnaryOp {
@ -100,7 +100,7 @@ impl Storage {
pub fn device(&self) -> Device { pub fn device(&self) -> Device {
match self { match self {
Self::Cpu(_) => Device::Cpu, Self::Cpu(_) => Device::Cpu,
Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id }, Self::Cuda(slice) => Device::Cuda(slice.device()),
} }
} }
@ -112,8 +112,8 @@ impl Storage {
} }
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> { pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
let lhs = self.device(); let lhs = self.device().location();
let rhs = rhs.device(); let rhs = rhs.device().location();
if lhs != rhs { if lhs != rhs {
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }) Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
} else { } else {
@ -179,8 +179,8 @@ impl Storage {
// Should not happen because of the same device check above but we're defensive // Should not happen because of the same device check above but we're defensive
// anyway. // anyway.
Err(Error::DeviceMismatchBinaryOp { Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device(), lhs: lhs.device().location(),
rhs: rhs.device(), rhs: rhs.device().location(),
op: B::NAME, op: B::NAME,
}) })
} }