mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Start adding support for cuda.
This commit is contained in:
@ -22,5 +22,5 @@ rand = "0.8.5"
|
|||||||
tokenizers = "0.13.3"
|
tokenizers = "0.13.3"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = ["cuda"]
|
||||||
cuda = ["dep:cudarc"]
|
cuda = ["dep:cudarc"]
|
||||||
|
@ -1,11 +1,19 @@
|
|||||||
use crate::{CpuStorage, DType, Result, Shape, Storage};
|
use crate::{CpuStorage, DType, Result, Shape, Storage};
|
||||||
|
|
||||||
|
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
|
||||||
|
/// can live on the same location (typically for cuda devices).
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
pub enum Device {
|
pub enum DeviceLocation {
|
||||||
Cpu,
|
Cpu,
|
||||||
Cuda { gpu_id: usize },
|
Cuda { gpu_id: usize },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum Device {
|
||||||
|
Cpu,
|
||||||
|
Cuda(std::sync::Arc<cudarc::driver::CudaDevice>),
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
|
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
|
||||||
pub trait NdArray {
|
pub trait NdArray {
|
||||||
fn shape(&self) -> Result<Shape>;
|
fn shape(&self) -> Result<Shape>;
|
||||||
@ -54,13 +62,22 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Device {
|
impl Device {
|
||||||
|
pub fn location(&self) -> DeviceLocation {
|
||||||
|
match self {
|
||||||
|
Self::Cpu => DeviceLocation::Cpu,
|
||||||
|
Self::Cuda(device) => DeviceLocation::Cuda {
|
||||||
|
gpu_id: device.ordinal(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||||
match self {
|
match self {
|
||||||
Device::Cpu => {
|
Device::Cpu => {
|
||||||
let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype));
|
let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype));
|
||||||
Ok(storage)
|
Ok(storage)
|
||||||
}
|
}
|
||||||
Device::Cuda { gpu_id: _ } => {
|
Device::Cuda(_) => {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -69,11 +86,12 @@ impl Device {
|
|||||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||||
match self {
|
match self {
|
||||||
Device::Cpu => {
|
Device::Cpu => {
|
||||||
let storage = Storage::Cpu(CpuStorage::zeros_impl(shape, dtype));
|
let storage = CpuStorage::zeros_impl(shape, dtype);
|
||||||
Ok(storage)
|
Ok(Storage::Cpu(storage))
|
||||||
}
|
}
|
||||||
Device::Cuda { gpu_id: _ } => {
|
Device::Cuda(device) => {
|
||||||
todo!()
|
let storage = device.alloc_zeros::<f32>(shape.elem_count())?;
|
||||||
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -84,7 +102,7 @@ impl Device {
|
|||||||
let storage = Storage::Cpu(array.to_cpu_storage());
|
let storage = Storage::Cpu(array.to_cpu_storage());
|
||||||
Ok(storage)
|
Ok(storage)
|
||||||
}
|
}
|
||||||
Device::Cuda { gpu_id: _ } => {
|
Device::Cuda(_) => {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::{DType, Device, Shape};
|
use crate::{DType, DeviceLocation, Shape};
|
||||||
|
|
||||||
/// Main library error type.
|
/// Main library error type.
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
@ -15,8 +15,8 @@ pub enum Error {
|
|||||||
|
|
||||||
#[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
#[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
||||||
DeviceMismatchBinaryOp {
|
DeviceMismatchBinaryOp {
|
||||||
lhs: Device,
|
lhs: DeviceLocation,
|
||||||
rhs: Device,
|
rhs: DeviceLocation,
|
||||||
op: &'static str,
|
op: &'static str,
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -33,6 +33,9 @@ pub enum Error {
|
|||||||
got: usize,
|
got: usize,
|
||||||
shape: Shape,
|
shape: Shape,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Cudarc(#[from] cudarc::driver::DriverError),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, Error>;
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
@ -9,7 +9,7 @@ mod strided_index;
|
|||||||
mod tensor;
|
mod tensor;
|
||||||
|
|
||||||
pub use cpu_backend::CpuStorage;
|
pub use cpu_backend::CpuStorage;
|
||||||
pub use device::Device;
|
pub use device::{Device, DeviceLocation};
|
||||||
pub use dtype::{DType, WithDType};
|
pub use dtype::{DType, WithDType};
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
pub use shape::Shape;
|
pub use shape::Shape;
|
||||||
|
@ -3,7 +3,7 @@ use crate::{CpuStorage, DType, Device, Error, Result, Shape};
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Storage {
|
pub enum Storage {
|
||||||
Cpu(CpuStorage),
|
Cpu(CpuStorage),
|
||||||
Cuda { gpu_id: usize }, // TODO: Actually add the storage.
|
Cuda(cudarc::driver::CudaSlice<f32>),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) trait UnaryOp {
|
pub(crate) trait UnaryOp {
|
||||||
@ -100,7 +100,7 @@ impl Storage {
|
|||||||
pub fn device(&self) -> Device {
|
pub fn device(&self) -> Device {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu(_) => Device::Cpu,
|
Self::Cpu(_) => Device::Cpu,
|
||||||
Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id },
|
Self::Cuda(slice) => Device::Cuda(slice.device()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,8 +112,8 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
||||||
let lhs = self.device();
|
let lhs = self.device().location();
|
||||||
let rhs = rhs.device();
|
let rhs = rhs.device().location();
|
||||||
if lhs != rhs {
|
if lhs != rhs {
|
||||||
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
|
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
|
||||||
} else {
|
} else {
|
||||||
@ -179,8 +179,8 @@ impl Storage {
|
|||||||
// Should not happen because of the same device check above but we're defensive
|
// Should not happen because of the same device check above but we're defensive
|
||||||
// anyway.
|
// anyway.
|
||||||
Err(Error::DeviceMismatchBinaryOp {
|
Err(Error::DeviceMismatchBinaryOp {
|
||||||
lhs: lhs.device(),
|
lhs: lhs.device().location(),
|
||||||
rhs: rhs.device(),
|
rhs: rhs.device().location(),
|
||||||
op: B::NAME,
|
op: B::NAME,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user