mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Start adding support for cuda.
This commit is contained in:
@ -1,11 +1,19 @@
|
||||
use crate::{CpuStorage, DType, Result, Shape, Storage};
|
||||
|
||||
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
|
||||
/// can live on the same location (typically for cuda devices).
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum Device {
|
||||
pub enum DeviceLocation {
|
||||
Cpu,
|
||||
Cuda { gpu_id: usize },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Device {
|
||||
Cpu,
|
||||
Cuda(std::sync::Arc<cudarc::driver::CudaDevice>),
|
||||
}
|
||||
|
||||
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
|
||||
pub trait NdArray {
|
||||
fn shape(&self) -> Result<Shape>;
|
||||
@ -54,13 +62,22 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub fn location(&self) -> DeviceLocation {
|
||||
match self {
|
||||
Self::Cpu => DeviceLocation::Cpu,
|
||||
Self::Cuda(device) => DeviceLocation::Cuda {
|
||||
gpu_id: device.ordinal(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype));
|
||||
Ok(storage)
|
||||
}
|
||||
Device::Cuda { gpu_id: _ } => {
|
||||
Device::Cuda(_) => {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
@ -69,11 +86,12 @@ impl Device {
|
||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = Storage::Cpu(CpuStorage::zeros_impl(shape, dtype));
|
||||
Ok(storage)
|
||||
let storage = CpuStorage::zeros_impl(shape, dtype);
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda { gpu_id: _ } => {
|
||||
todo!()
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.alloc_zeros::<f32>(shape.elem_count())?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -84,7 +102,7 @@ impl Device {
|
||||
let storage = Storage::Cpu(array.to_cpu_storage());
|
||||
Ok(storage)
|
||||
}
|
||||
Device::Cuda { gpu_id: _ } => {
|
||||
Device::Cuda(_) => {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user