mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add some skeleton code for GPU support.
This commit is contained in:
@ -6,6 +6,7 @@ use crate::{
|
|||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
pub enum Device {
|
pub enum Device {
|
||||||
Cpu,
|
Cpu,
|
||||||
|
Cuda { gpu_id: usize },
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
|
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
|
||||||
@ -72,6 +73,9 @@ impl Device {
|
|||||||
};
|
};
|
||||||
Storage::Cpu(storage)
|
Storage::Cpu(storage)
|
||||||
}
|
}
|
||||||
|
Device::Cuda { gpu_id: _ } => {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,12 +95,18 @@ impl Device {
|
|||||||
};
|
};
|
||||||
Storage::Cpu(storage)
|
Storage::Cpu(storage)
|
||||||
}
|
}
|
||||||
|
Device::Cuda { gpu_id: _ } => {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
|
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
|
||||||
match self {
|
match self {
|
||||||
Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
|
Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
|
||||||
|
Device::Cuda { gpu_id: _ } => {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -20,6 +20,7 @@ impl CpuStorage {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Storage {
|
pub enum Storage {
|
||||||
Cpu(CpuStorage),
|
Cpu(CpuStorage),
|
||||||
|
Cuda { gpu_id: usize }, // TODO: Actually add the storage.
|
||||||
}
|
}
|
||||||
|
|
||||||
trait UnaryOp {
|
trait UnaryOp {
|
||||||
@ -116,12 +117,14 @@ impl Storage {
|
|||||||
pub fn device(&self) -> Device {
|
pub fn device(&self) -> Device {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu(_) => Device::Cpu,
|
Self::Cpu(_) => Device::Cpu,
|
||||||
|
Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu(storage) => storage.dtype(),
|
Self::Cpu(storage) => storage.dtype(),
|
||||||
|
Self::Cuda { .. } => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,6 +171,7 @@ impl Storage {
|
|||||||
Ok(Storage::Cpu(CpuStorage::F64(data)))
|
Ok(Storage::Cpu(CpuStorage::F64(data)))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
Self::Cuda { .. } => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -186,6 +190,7 @@ impl Storage {
|
|||||||
Ok(Storage::Cpu(CpuStorage::F64(data)))
|
Ok(Storage::Cpu(CpuStorage::F64(data)))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
Self::Cuda { .. } => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -232,6 +237,16 @@ impl Storage {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
(Self::Cuda { .. }, Self::Cuda { .. }) => todo!(),
|
||||||
|
(lhs, rhs) => {
|
||||||
|
// Should not happen because of the same device check above but we're defensive
|
||||||
|
// anyway.
|
||||||
|
Err(Error::DeviceMismatchBinaryOp {
|
||||||
|
lhs: lhs.device(),
|
||||||
|
rhs: rhs.device(),
|
||||||
|
op: B::NAME,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,6 +209,7 @@ impl Tensor {
|
|||||||
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||||
Ok(data[0])
|
Ok(data[0])
|
||||||
}
|
}
|
||||||
|
Storage::Cuda { .. } => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -249,6 +250,7 @@ impl Tensor {
|
|||||||
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||||
Ok(self.strided_index().map(|i| data[i]).collect())
|
Ok(self.strided_index().map(|i| data[i]).collect())
|
||||||
}
|
}
|
||||||
|
Storage::Cuda { .. } => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -266,6 +268,7 @@ impl Tensor {
|
|||||||
assert!(src_index.next().is_none());
|
assert!(src_index.next().is_none());
|
||||||
Ok(rows)
|
Ok(rows)
|
||||||
}
|
}
|
||||||
|
Storage::Cuda { .. } => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user